From c71c8a87677a80431687a92fe7323156727046eb Mon Sep 17 00:00:00 2001 From: francescodeaglio Date: Thu, 22 Jun 2023 10:42:03 +0200 Subject: [PATCH 01/17] CSV file wrapper introduction --- .../internal/file_wrapper/csv_file_wrapper.py | 127 ++++++++++++++++++ .../file_wrapper/file_wrapper_type.py | 1 + 2 files changed, 128 insertions(+) create mode 100644 modyn/storage/internal/file_wrapper/csv_file_wrapper.py diff --git a/modyn/storage/internal/file_wrapper/csv_file_wrapper.py b/modyn/storage/internal/file_wrapper/csv_file_wrapper.py new file mode 100644 index 000000000..ece2a2c76 --- /dev/null +++ b/modyn/storage/internal/file_wrapper/csv_file_wrapper.py @@ -0,0 +1,127 @@ +import csv +from typing import Iterator, Optional + +from modyn.storage.internal.file_wrapper.abstract_file_wrapper import AbstractFileWrapper +from modyn.storage.internal.file_wrapper.file_wrapper_type import FileWrapperType +from modyn.storage.internal.filesystem_wrapper.abstract_filesystem_wrapper import AbstractFileSystemWrapper + + +class CsvFileWrapper(AbstractFileWrapper): + def __init__(self, file_path: str, file_wrapper_config: dict, filesystem_wrapper: AbstractFileSystemWrapper): + super().__init__(file_path, file_wrapper_config, filesystem_wrapper) + + self.file_wrapper_type = FileWrapperType.CsvFileWrapper + + if "separator" in file_wrapper_config: + self.separator = file_wrapper_config["separator"] + else: + self.separator = ";" + + if "label_column" not in file_wrapper_config: + raise ValueError( + "Please specify the index of the column that contains the label. " + "Use None if no column contains the label" + ) + self.label_index = file_wrapper_config["label_column"] + + if "ignore_header" in file_wrapper_config: + self.ignore_header = file_wrapper_config["ignore_header"] + else: + self.ignore_header = False + + if "encoding" in file_wrapper_config: + self.encoding = file_wrapper_config["encoding"] + else: + self.encoding = "utf-8" + + self._validate_file_extension() + + def _validate_file_extension(self) -> None: + """Validates the file extension as csv + + Raises: + ValueError: File has wrong file extension + """ + if not self.file_path.endswith(".csv"): + raise ValueError("File has wrong file extension.") + + def get_sample(self, index: int) -> bytes: + samples = self._filter_rows_samples([index]) + assert len(samples) == 1 + + return samples[0] + + def get_samples(self, start: int, end: int) -> list[bytes]: + indices = list(range(start, end)) + return self.get_samples_from_indices(indices) + + def get_samples_from_indices(self, indices: list) -> list[bytes]: + samples = self._filter_rows_samples(indices) + assert len(samples) == len(indices) + + return samples + + def get_label(self, index: int) -> Optional[int]: + if self.label_index is None: + return None + + labels = self._filter_rows_labels([index]) + assert len(labels) == 0 + return labels[0] + + def get_all_labels(self) -> list[Optional[int]]: + reader = self._get_csv_reader() + + labels = [] + for row in reader: + only_label = row[self.label_index] + labels.append(only_label) + + return labels + + def get_number_of_samples(self) -> int: + reader = self._get_csv_reader() + + return sum(1 for _ in reader) + + def _get_csv_reader(self) -> Iterator: + data_file = self.filesystem_wrapper.get(self.file_path) + + # Convert bytes content to a string + data_file_str = data_file.decode(self.encoding) + + lines = data_file_str.split("\n") + + # Create a CSV reader + reader = csv.reader(lines, delimiter=self.separator) + + # skip the header if required + if self.ignore_header: + next(reader) + + return reader + + def _filter_rows_samples(self, indices: list[int]) -> list[bytes]: + reader = self._get_csv_reader() + + # Iterate over the rows and keep the selected ones + filtered_rows = [] + for i, row in enumerate(reader): + if i in indices: + # Remove the label, convert the row to bytes and append to the list + row_without_label = [col for j, col in enumerate(row) if j != self.label_index] + filtered_rows.append(bytes(self.separator.join(row_without_label), self.encoding)) + + return filtered_rows + + def _filter_rows_labels(self, indices: list[int]) -> list[int]: + reader = self._get_csv_reader() + + # Iterate over the rows and keep the selected ones + filtered_rows = [] + for i, row in enumerate(reader): + if i in indices: + only_label = row[self.label_index] + filtered_rows.append(only_label) + + return filtered_rows diff --git a/modyn/storage/internal/file_wrapper/file_wrapper_type.py b/modyn/storage/internal/file_wrapper/file_wrapper_type.py index 4a9f1c32b..758b99cbe 100644 --- a/modyn/storage/internal/file_wrapper/file_wrapper_type.py +++ b/modyn/storage/internal/file_wrapper/file_wrapper_type.py @@ -12,6 +12,7 @@ class FileWrapperType(Enum): SingleSampleFileWrapper = "single_sample_file_wrapper" # pylint: disable=invalid-name BinaryFileWrapper = "binary_file_wrapper" # pylint: disable=invalid-name + CsvFileWrapper = "csv_file_wrapper" # pylint: disable=invalid-name class InvalidFileWrapperTypeException(Exception): From b89111e1c3de09d7ceda2f33e2c642d23c0b7948 Mon Sep 17 00:00:00 2001 From: francescodeaglio Date: Thu, 22 Jun 2023 10:56:04 +0200 Subject: [PATCH 02/17] Get all labels when the label is not present --- .../storage/internal/file_wrapper/csv_file_wrapper.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/modyn/storage/internal/file_wrapper/csv_file_wrapper.py b/modyn/storage/internal/file_wrapper/csv_file_wrapper.py index ece2a2c76..af2636f81 100644 --- a/modyn/storage/internal/file_wrapper/csv_file_wrapper.py +++ b/modyn/storage/internal/file_wrapper/csv_file_wrapper.py @@ -24,10 +24,10 @@ def __init__(self, file_path: str, file_wrapper_config: dict, filesystem_wrapper ) self.label_index = file_wrapper_config["label_column"] - if "ignore_header" in file_wrapper_config: - self.ignore_header = file_wrapper_config["ignore_header"] + if "ignore_first_line" in file_wrapper_config: + self.ignore_first_line = file_wrapper_config["ignore_header"] else: - self.ignore_header = False + self.ignore_first_line = False if "encoding" in file_wrapper_config: self.encoding = file_wrapper_config["encoding"] @@ -72,6 +72,9 @@ def get_label(self, index: int) -> Optional[int]: def get_all_labels(self) -> list[Optional[int]]: reader = self._get_csv_reader() + if self.label_index is None: + return [None] * self.get_number_of_samples() + labels = [] for row in reader: only_label = row[self.label_index] @@ -96,7 +99,7 @@ def _get_csv_reader(self) -> Iterator: reader = csv.reader(lines, delimiter=self.separator) # skip the header if required - if self.ignore_header: + if self.ignore_first_line: next(reader) return reader From e8f6b1d6e43724d594b23a782db558a50c45f591 Mon Sep 17 00:00:00 2001 From: francescodeaglio Date: Thu, 22 Jun 2023 11:02:31 +0200 Subject: [PATCH 03/17] Fake delete_samples --- modyn/storage/internal/file_wrapper/csv_file_wrapper.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/modyn/storage/internal/file_wrapper/csv_file_wrapper.py b/modyn/storage/internal/file_wrapper/csv_file_wrapper.py index af2636f81..f30481b45 100644 --- a/modyn/storage/internal/file_wrapper/csv_file_wrapper.py +++ b/modyn/storage/internal/file_wrapper/csv_file_wrapper.py @@ -128,3 +128,6 @@ def _filter_rows_labels(self, indices: list[int]) -> list[int]: filtered_rows.append(only_label) return filtered_rows + + def delete_samples(self, indices: list) -> None: + return From 2261fcb27fc64d1d24121984ea652c8e7182b9e9 Mon Sep 17 00:00:00 2001 From: francescodeaglio Date: Thu, 22 Jun 2023 11:12:05 +0200 Subject: [PATCH 04/17] Assertions changed to index errors --- .../internal/file_wrapper/csv_file_wrapper.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/modyn/storage/internal/file_wrapper/csv_file_wrapper.py b/modyn/storage/internal/file_wrapper/csv_file_wrapper.py index f30481b45..c33fa6663 100644 --- a/modyn/storage/internal/file_wrapper/csv_file_wrapper.py +++ b/modyn/storage/internal/file_wrapper/csv_file_wrapper.py @@ -17,15 +17,15 @@ def __init__(self, file_path: str, file_wrapper_config: dict, filesystem_wrapper else: self.separator = ";" - if "label_column" not in file_wrapper_config: + if "label_index" not in file_wrapper_config: raise ValueError( "Please specify the index of the column that contains the label. " "Use None if no column contains the label" ) - self.label_index = file_wrapper_config["label_column"] + self.label_index = file_wrapper_config["label_index"] if "ignore_first_line" in file_wrapper_config: - self.ignore_first_line = file_wrapper_config["ignore_header"] + self.ignore_first_line = file_wrapper_config["ignore_first_line"] else: self.ignore_first_line = False @@ -47,7 +47,9 @@ def _validate_file_extension(self) -> None: def get_sample(self, index: int) -> bytes: samples = self._filter_rows_samples([index]) - assert len(samples) == 1 + + if len(samples) != 1: + raise IndexError("Invalid index") return samples[0] @@ -57,7 +59,9 @@ def get_samples(self, start: int, end: int) -> list[bytes]: def get_samples_from_indices(self, indices: list) -> list[bytes]: samples = self._filter_rows_samples(indices) - assert len(samples) == len(indices) + + if len(samples) != len(indices): + raise IndexError("At least one index is invalid.") return samples @@ -66,7 +70,10 @@ def get_label(self, index: int) -> Optional[int]: return None labels = self._filter_rows_labels([index]) - assert len(labels) == 0 + + if len(labels) != 1: + raise IndexError("Invalid index.") + return labels[0] def get_all_labels(self) -> list[Optional[int]]: From 35033a15deaccec1ab2f1ab99e9526d8b4a56d52 Mon Sep 17 00:00:00 2001 From: francescodeaglio Date: Thu, 22 Jun 2023 11:31:28 +0200 Subject: [PATCH 05/17] Fixed typing. Convert labels to integer. --- .../internal/file_wrapper/csv_file_wrapper.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/modyn/storage/internal/file_wrapper/csv_file_wrapper.py b/modyn/storage/internal/file_wrapper/csv_file_wrapper.py index c33fa6663..b26f822e5 100644 --- a/modyn/storage/internal/file_wrapper/csv_file_wrapper.py +++ b/modyn/storage/internal/file_wrapper/csv_file_wrapper.py @@ -79,13 +79,15 @@ def get_label(self, index: int) -> Optional[int]: def get_all_labels(self) -> list[Optional[int]]: reader = self._get_csv_reader() + labels: list[Optional[int]] = [] + if self.label_index is None: return [None] * self.get_number_of_samples() - labels = [] for row in reader: - only_label = row[self.label_index] - labels.append(only_label) + # labels are integer in modyn + int_label = int(row[self.label_index]) + labels.append(int_label) return labels @@ -131,8 +133,9 @@ def _filter_rows_labels(self, indices: list[int]) -> list[int]: filtered_rows = [] for i, row in enumerate(reader): if i in indices: - only_label = row[self.label_index] - filtered_rows.append(only_label) + # labels are integer in modyn + int_label = int(row[self.label_index]) + filtered_rows.append(int_label) return filtered_rows From 3e587a6401d6eb7975935766ee07ee9fbc4fd359 Mon Sep 17 00:00:00 2001 From: francescodeaglio Date: Thu, 22 Jun 2023 11:31:50 +0200 Subject: [PATCH 06/17] Basic tests --- .../file_wrapper/test_csv_file_wrapper.py | 134 ++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 modyn/tests/storage/internal/file_wrapper/test_csv_file_wrapper.py diff --git a/modyn/tests/storage/internal/file_wrapper/test_csv_file_wrapper.py b/modyn/tests/storage/internal/file_wrapper/test_csv_file_wrapper.py new file mode 100644 index 000000000..a852ffb54 --- /dev/null +++ b/modyn/tests/storage/internal/file_wrapper/test_csv_file_wrapper.py @@ -0,0 +1,134 @@ +import os +import pathlib +import shutil + +import pytest +from modyn.storage.internal.file_wrapper.csv_file_wrapper import CsvFileWrapper +from modyn.storage.internal.file_wrapper.file_wrapper_type import FileWrapperType + +TMP_DIR = str(pathlib.Path(os.path.abspath(__file__)).parent / "test_tmp" / "modyn") +FILE_PATH = str(pathlib.Path(os.path.abspath(__file__)).parent / "test_tmp" / "modyn" / "test.csv") +FILE_DATA = b"a;b;c;d;12\ne;f;g;h;76" +INVALID_FILE_EXTENSION_PATH = str(pathlib.Path(os.path.abspath(__file__)).parent / "test_tmp" / "modyn" / "test.txt") +FILE_WRAPPER_CONFIG = { + "ignore_first_line": False, + "label_index": 4, + "separator": ";", +} + + +def setup(): + os.makedirs(TMP_DIR, exist_ok=True) + + with open(FILE_PATH, "wb") as file: + file.write(FILE_DATA) + + +def teardown(): + os.remove(FILE_PATH) + shutil.rmtree(TMP_DIR) + + +class MockFileSystemWrapper: + def __init__(self, file_path): + self.file_path = file_path + + def get(self, file_path): + with open(file_path, "rb") as file: + return file.read() + + def get_size(self, path): + return os.path.getsize(path) + + +def test_init(): + file_wrapper = CsvFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) + assert file_wrapper.file_path == FILE_PATH + assert file_wrapper.file_wrapper_type == FileWrapperType.CsvFileWrapper + assert file_wrapper.encoding == "utf-8" + assert file_wrapper.label_index == 4 + assert not file_wrapper.ignore_first_line + assert file_wrapper.separator == ";" + + +def test_init_with_invalid_file_extension(): + with pytest.raises(ValueError): + CsvFileWrapper( + INVALID_FILE_EXTENSION_PATH, + FILE_WRAPPER_CONFIG, + MockFileSystemWrapper(INVALID_FILE_EXTENSION_PATH), + ) + + +def test_get_number_of_samples(): + file_wrapper = CsvFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) + assert file_wrapper.get_number_of_samples() == 2 + + # check if the first line is correctly ignored + file_wrapper.ignore_first_line = True + assert file_wrapper.get_number_of_samples() == 1 + + +def test_get_sample(): + file_wrapper = CsvFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) + sample = file_wrapper.get_sample(0) + assert sample == b"a;b;c;d" + + sample = file_wrapper.get_sample(1) + assert sample == b"e;f;g;h" + + +def test_get_sample_with_invalid_index(): + file_wrapper = CsvFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) + with pytest.raises(IndexError): + file_wrapper.get_sample(10) + + +def test_get_label(): + file_wrapper = CsvFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) + label = file_wrapper.get_label(0) + assert label == 12 + + label = file_wrapper.get_label(1) + assert label == 76 + + with pytest.raises(IndexError): + file_wrapper.get_label(2) + + +def test_get_all_labels(): + file_wrapper = CsvFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) + assert file_wrapper.get_all_labels() == [12, 76] + + +def test_get_label_with_invalid_index(): + file_wrapper = CsvFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) + with pytest.raises(IndexError): + file_wrapper.get_label(10) + + +def test_get_samples(): + file_wrapper = CsvFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) + samples = file_wrapper.get_samples(0, 1) + assert len(samples) == 1 + assert samples[0] == b"a;b;c;d" + + samples = file_wrapper.get_samples(0, 2) + assert len(samples) == 2 + assert samples[0] == b"a;b;c;d" + assert samples[1] == b"e;f;g;h" + + +def test_get_samples_with_invalid_index(): + file_wrapper = CsvFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) + with pytest.raises(IndexError): + file_wrapper.get_samples(0, 5) + + with pytest.raises(IndexError): + file_wrapper.get_samples(3, 4) + + +def test_get_samples_from_indices_with_invalid_indices(): + file_wrapper = CsvFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) + with pytest.raises(IndexError): + file_wrapper.get_samples_from_indices([-2, 1]) From 10c4da0f392c19c97fc5f374e76d3286f77f51d6 Mon Sep 17 00:00:00 2001 From: francescodeaglio Date: Thu, 22 Jun 2023 11:51:53 +0200 Subject: [PATCH 07/17] Validate file content --- .../internal/file_wrapper/csv_file_wrapper.py | 19 ++++++++++ .../file_wrapper/test_csv_file_wrapper.py | 36 +++++++++++++++++++ 2 files changed, 55 insertions(+) diff --git a/modyn/storage/internal/file_wrapper/csv_file_wrapper.py b/modyn/storage/internal/file_wrapper/csv_file_wrapper.py index b26f822e5..815c9a415 100644 --- a/modyn/storage/internal/file_wrapper/csv_file_wrapper.py +++ b/modyn/storage/internal/file_wrapper/csv_file_wrapper.py @@ -35,6 +35,7 @@ def __init__(self, file_path: str, file_wrapper_config: dict, filesystem_wrapper self.encoding = "utf-8" self._validate_file_extension() + self._validate_file_content() def _validate_file_extension(self) -> None: """Validates the file extension as csv @@ -45,6 +46,24 @@ def _validate_file_extension(self) -> None: if not self.file_path.endswith(".csv"): raise ValueError("File has wrong file extension.") + def _validate_file_content(self): + + reader = self._get_csv_reader() + + number_of_columns = [] + + for i, row in enumerate(reader): + + number_of_columns.append(len(row)) + if self.label_index is not None: + if not 0 <= self.label_index < len(row): + raise ValueError("Label index outside row boundary") + if not row[self.label_index].isnumeric(): + raise ValueError("The label must be an integer") + + if not len(set(number_of_columns)) == 1: + raise ValueError("Some rows have different width. " + f"This is the number of columns row by row {number_of_columns}") def get_sample(self, index: int) -> bytes: samples = self._filter_rows_samples([index]) diff --git a/modyn/tests/storage/internal/file_wrapper/test_csv_file_wrapper.py b/modyn/tests/storage/internal/file_wrapper/test_csv_file_wrapper.py index a852ffb54..3df9a8dd4 100644 --- a/modyn/tests/storage/internal/file_wrapper/test_csv_file_wrapper.py +++ b/modyn/tests/storage/internal/file_wrapper/test_csv_file_wrapper.py @@ -8,6 +8,7 @@ TMP_DIR = str(pathlib.Path(os.path.abspath(__file__)).parent / "test_tmp" / "modyn") FILE_PATH = str(pathlib.Path(os.path.abspath(__file__)).parent / "test_tmp" / "modyn" / "test.csv") +WRONG_FILE_PATH = str(pathlib.Path(os.path.abspath(__file__)).parent / "test_tmp" / "modyn" / "wrong_test.csv") FILE_DATA = b"a;b;c;d;12\ne;f;g;h;76" INVALID_FILE_EXTENSION_PATH = str(pathlib.Path(os.path.abspath(__file__)).parent / "test_tmp" / "modyn" / "test.txt") FILE_WRAPPER_CONFIG = { @@ -132,3 +133,38 @@ def test_get_samples_from_indices_with_invalid_indices(): file_wrapper = CsvFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) with pytest.raises(IndexError): file_wrapper.get_samples_from_indices([-2, 1]) + +def write_to_file(wrong_data): + with open(WRONG_FILE_PATH, "wb") as file: + file.write(wrong_data) + +def test_invalid_file_content(): + # extra field in one row + wrong_data = b"a;b;c;d;12;e\ne;f;g;h;76" + write_to_file(wrong_data) + + with pytest.raises(ValueError): + _ = CsvFileWrapper(WRONG_FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(WRONG_FILE_PATH)) + + # label column outside boundary + wrong_data = b"a;b;c;12\ne;f;g;76" + write_to_file(wrong_data) + + with pytest.raises(ValueError): + _ = CsvFileWrapper(WRONG_FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(WRONG_FILE_PATH)) + + # str label column + wrong_data = b"a;b;c;d;e;12\ne;f;g;h;h;76" + write_to_file(wrong_data) + with pytest.raises(ValueError): + _ = CsvFileWrapper(WRONG_FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(WRONG_FILE_PATH)) + + # just one str in label + wrong_data = b"a;b;c;d;88;12\ne;f;g;h;h;76" + write_to_file(wrong_data) + with pytest.raises(ValueError): + _ = CsvFileWrapper(WRONG_FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(WRONG_FILE_PATH)) + + + + From 8b1e841108bab15d049ac6d5bfc91bbd7eec9673 Mon Sep 17 00:00:00 2001 From: francescodeaglio Date: Thu, 22 Jun 2023 12:00:25 +0200 Subject: [PATCH 08/17] Test TSV vs CSV (different separator) --- .../file_wrapper/test_csv_file_wrapper.py | 33 +++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/modyn/tests/storage/internal/file_wrapper/test_csv_file_wrapper.py b/modyn/tests/storage/internal/file_wrapper/test_csv_file_wrapper.py index 3df9a8dd4..d471437b5 100644 --- a/modyn/tests/storage/internal/file_wrapper/test_csv_file_wrapper.py +++ b/modyn/tests/storage/internal/file_wrapper/test_csv_file_wrapper.py @@ -134,9 +134,9 @@ def test_get_samples_from_indices_with_invalid_indices(): with pytest.raises(IndexError): file_wrapper.get_samples_from_indices([-2, 1]) -def write_to_file(wrong_data): +def write_to_file(data): with open(WRONG_FILE_PATH, "wb") as file: - file.write(wrong_data) + file.write(data) def test_invalid_file_content(): # extra field in one row @@ -165,6 +165,35 @@ def test_invalid_file_content(): with pytest.raises(ValueError): _ = CsvFileWrapper(WRONG_FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(WRONG_FILE_PATH)) +def test_different_separator(): + TSV_FILE_DATA = b"a\tb\tc\td\t12\ne\tf\tg\th\t76" + + TSV_FILE_WRAPPER_CONFIG = { + "ignore_first_line": False, + "label_index": 4, + "separator": "\t", + } + + write_to_file(TSV_FILE_DATA) + tsv_file_wrapper = CsvFileWrapper(WRONG_FILE_PATH, TSV_FILE_WRAPPER_CONFIG, MockFileSystemWrapper(WRONG_FILE_PATH)) + csv_file_wrapper = CsvFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) + + assert tsv_file_wrapper.get_number_of_samples() == csv_file_wrapper.get_number_of_samples() + + + assert tsv_file_wrapper.get_sample(0) == b'a\tb\tc\td' + assert tsv_file_wrapper.get_sample(1) == b'e\tf\tg\th' + + tsv_samples = tsv_file_wrapper.get_samples(0,2) + csv_samples = csv_file_wrapper.get_samples(0,2) + + tsv_samples = [ sample.decode("utf-8").split("\t") for sample in tsv_samples] + csv_samples = [ sample.decode("utf-8").split(";") for sample in csv_samples] + assert tsv_samples == csv_samples + + assert tsv_file_wrapper.get_label(0) == csv_file_wrapper.get_label(0) + assert tsv_file_wrapper.get_label(1) == csv_file_wrapper.get_label(1) + From b5c1555f7442bc1b0dffdf470e22a3bb73ee4522 Mon Sep 17 00:00:00 2001 From: francescodeaglio Date: Thu, 22 Jun 2023 12:22:47 +0200 Subject: [PATCH 09/17] Basic documentation --- .../internal/file_wrapper/csv_file_wrapper.py | 38 ++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/modyn/storage/internal/file_wrapper/csv_file_wrapper.py b/modyn/storage/internal/file_wrapper/csv_file_wrapper.py index 815c9a415..2ca75f0d2 100644 --- a/modyn/storage/internal/file_wrapper/csv_file_wrapper.py +++ b/modyn/storage/internal/file_wrapper/csv_file_wrapper.py @@ -24,6 +24,7 @@ def __init__(self, file_path: str, file_wrapper_config: dict, filesystem_wrapper ) self.label_index = file_wrapper_config["label_index"] + # the first line might contain the header, which is useless and must not be returned. if "ignore_first_line" in file_wrapper_config: self.ignore_first_line = file_wrapper_config["ignore_first_line"] else: @@ -34,7 +35,9 @@ def __init__(self, file_path: str, file_wrapper_config: dict, filesystem_wrapper else: self.encoding = "utf-8" + # check that the file is actually a CSV self._validate_file_extension() + self._validate_file_content() def _validate_file_extension(self) -> None: @@ -47,6 +50,14 @@ def _validate_file_extension(self) -> None: raise ValueError("File has wrong file extension.") def _validate_file_content(self): + """ + Performs the following checks: + - specified label column is castable to integer + - each row has the label_index_column + - each row has the same width + + Raises a ValueError if a condition is not met + """ reader = self._get_csv_reader() @@ -58,13 +69,14 @@ def _validate_file_content(self): if self.label_index is not None: if not 0 <= self.label_index < len(row): raise ValueError("Label index outside row boundary") - if not row[self.label_index].isnumeric(): + if not row[self.label_index].isnumeric(): #returns true iff all the characters are numbers raise ValueError("The label must be an integer") if not len(set(number_of_columns)) == 1: raise ValueError("Some rows have different width. " f"This is the number of columns row by row {number_of_columns}") def get_sample(self, index: int) -> bytes: + samples = self._filter_rows_samples([index]) if len(samples) != 1: @@ -116,6 +128,11 @@ def get_number_of_samples(self) -> int: return sum(1 for _ in reader) def _get_csv_reader(self) -> Iterator: + """ + Receives the bytes from the file_system_wrapper and creates a csv.reader out of it. + Returns: + csv.reader + """ data_file = self.filesystem_wrapper.get(self.file_path) # Convert bytes content to a string @@ -133,6 +150,15 @@ def _get_csv_reader(self) -> Iterator: return reader def _filter_rows_samples(self, indices: list[int]) -> list[bytes]: + """ + Filters the selected rows and removes the label column + Args: + indices: list of rows that must be kept + + Returns: + list of byte-encoded rows + + """ reader = self._get_csv_reader() # Iterate over the rows and keep the selected ones @@ -141,11 +167,21 @@ def _filter_rows_samples(self, indices: list[int]) -> list[bytes]: if i in indices: # Remove the label, convert the row to bytes and append to the list row_without_label = [col for j, col in enumerate(row) if j != self.label_index] + # the row is transformed in a similar csv using the same separator and then transformed to bytes filtered_rows.append(bytes(self.separator.join(row_without_label), self.encoding)) return filtered_rows def _filter_rows_labels(self, indices: list[int]) -> list[int]: + """ + Filters the selected rows and extracts the label column + Args: + indices: list of rows that must be kept + + Returns: + list of labels + + """ reader = self._get_csv_reader() # Iterate over the rows and keep the selected ones From 487f96c1f11b597cd97d134ab5b720e2737ab27b Mon Sep 17 00:00:00 2001 From: francescodeaglio Date: Thu, 22 Jun 2023 14:14:02 +0200 Subject: [PATCH 10/17] Added Pipeline Config --- modyn/config/schema/modyn_config_schema.yaml | 32 ++++++++++++++++--- .../internal/file_wrapper/csv_file_wrapper.py | 16 +++++----- .../file_wrapper/test_csv_file_wrapper.py | 28 ++++++++-------- 3 files changed, 48 insertions(+), 28 deletions(-) diff --git a/modyn/config/schema/modyn_config_schema.yaml b/modyn/config/schema/modyn_config_schema.yaml index bb25987cd..81004ab97 100644 --- a/modyn/config/schema/modyn_config_schema.yaml +++ b/modyn/config/schema/modyn_config_schema.yaml @@ -37,7 +37,8 @@ properties: sample_batch_size: type: number description: | - The size of a batch when requesting new samples from storage. All new samples are returned, however, to reduce the size of a single answer the keys are batched in sizes of `sample_batch_size`. + The size of a batch when requesting new samples from storage. All new samples are returned, however, to reduce + the size of a single answer the keys are batched in sizes of `sample_batch_size`. sample_dbinsertion_batchsize: type: number description: | @@ -49,7 +50,9 @@ properties: sample_table_unlogged: type: boolean description: | - This configures whether the table storing all samples is UNLOGGED (= high performance) or crash resilient. Defaults to True. For datasets with many samples (such as Criteo), this is recommended for highest insertion performance. In other scenarios, this might not be necessary. + This configures whether the table storing all samples is UNLOGGED (= high performance) or crash resilient. + Defaults to True. For datasets with many samples (such as Criteo), this is recommended for highest insertion performance. + In other scenarios, this might not be necessary. force_fallback_insert: type: boolean description: | @@ -99,15 +102,34 @@ properties: record_size: type: number description: | - The size of each full record in bytes (label + features) for a binary file wrapper. + [BinaryFileWrapper] The size of each full record in bytes (label + features). label_size: type: number description: | - The size of the label field in bytes for a binary file wrapper. + [BinaryFileWrapper] The size of the label field in bytes for a binary file wrapper. byteorder: type: string description: | - The byteorder when reading an integer from multibyte data in a binary file. Should either be "big" or "little". + [BinaryFileWrapper] The byteorder when reading an integer from multibyte data in a binary file. + Should either be "big" or "little". + separator: + type: string + description: | + [CsvFileWrapper] The separator used in the CSV file. The default is ";" + label_index: + type: number + description: | + [CsvFileWrapper] Column index of the label. + For example, if the columns are "width", "height", "age", "label" you should set label_index to 3. + ignore_first_line: + type: boolean + description: | + [CsvFileWrapper] If the first line is the table header, you can skip it setting this parameter to True. + Default is False. + encoding: + type: string + description: | + [CsvFileWrapper] Encoding of your file. Default is utf-8 ignore_last_timestamp: type: boolean description: | diff --git a/modyn/storage/internal/file_wrapper/csv_file_wrapper.py b/modyn/storage/internal/file_wrapper/csv_file_wrapper.py index 2ca75f0d2..12ae49ead 100644 --- a/modyn/storage/internal/file_wrapper/csv_file_wrapper.py +++ b/modyn/storage/internal/file_wrapper/csv_file_wrapper.py @@ -49,7 +49,7 @@ def _validate_file_extension(self) -> None: if not self.file_path.endswith(".csv"): raise ValueError("File has wrong file extension.") - def _validate_file_content(self): + def _validate_file_content(self) -> None: """ Performs the following checks: - specified label column is castable to integer @@ -63,20 +63,20 @@ def _validate_file_content(self): number_of_columns = [] - for i, row in enumerate(reader): - + for row in reader: number_of_columns.append(len(row)) if self.label_index is not None: if not 0 <= self.label_index < len(row): raise ValueError("Label index outside row boundary") - if not row[self.label_index].isnumeric(): #returns true iff all the characters are numbers + if not row[self.label_index].isnumeric(): # returns true iff all the characters are numbers raise ValueError("The label must be an integer") - if not len(set(number_of_columns)) == 1: - raise ValueError("Some rows have different width. " - f"This is the number of columns row by row {number_of_columns}") - def get_sample(self, index: int) -> bytes: + if len(set(number_of_columns)) != 1: + raise ValueError( + "Some rows have different width. " f"This is the number of columns row by row {number_of_columns}" + ) + def get_sample(self, index: int) -> bytes: samples = self._filter_rows_samples([index]) if len(samples) != 1: diff --git a/modyn/tests/storage/internal/file_wrapper/test_csv_file_wrapper.py b/modyn/tests/storage/internal/file_wrapper/test_csv_file_wrapper.py index d471437b5..a0aa153da 100644 --- a/modyn/tests/storage/internal/file_wrapper/test_csv_file_wrapper.py +++ b/modyn/tests/storage/internal/file_wrapper/test_csv_file_wrapper.py @@ -134,10 +134,12 @@ def test_get_samples_from_indices_with_invalid_indices(): with pytest.raises(IndexError): file_wrapper.get_samples_from_indices([-2, 1]) + def write_to_file(data): with open(WRONG_FILE_PATH, "wb") as file: file.write(data) + def test_invalid_file_content(): # extra field in one row wrong_data = b"a;b;c;d;12;e\ne;f;g;h;76" @@ -165,35 +167,31 @@ def test_invalid_file_content(): with pytest.raises(ValueError): _ = CsvFileWrapper(WRONG_FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(WRONG_FILE_PATH)) + def test_different_separator(): - TSV_FILE_DATA = b"a\tb\tc\td\t12\ne\tf\tg\th\t76" + tsv_file_data = b"a\tb\tc\td\t12\ne\tf\tg\th\t76" - TSV_FILE_WRAPPER_CONFIG = { + tsv_file_wrapper_config = { "ignore_first_line": False, "label_index": 4, "separator": "\t", } - write_to_file(TSV_FILE_DATA) - tsv_file_wrapper = CsvFileWrapper(WRONG_FILE_PATH, TSV_FILE_WRAPPER_CONFIG, MockFileSystemWrapper(WRONG_FILE_PATH)) + write_to_file(tsv_file_data) + tsv_file_wrapper = CsvFileWrapper(WRONG_FILE_PATH, tsv_file_wrapper_config, MockFileSystemWrapper(WRONG_FILE_PATH)) csv_file_wrapper = CsvFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) assert tsv_file_wrapper.get_number_of_samples() == csv_file_wrapper.get_number_of_samples() + assert tsv_file_wrapper.get_sample(0) == b"a\tb\tc\td" + assert tsv_file_wrapper.get_sample(1) == b"e\tf\tg\th" - assert tsv_file_wrapper.get_sample(0) == b'a\tb\tc\td' - assert tsv_file_wrapper.get_sample(1) == b'e\tf\tg\th' - - tsv_samples = tsv_file_wrapper.get_samples(0,2) - csv_samples = csv_file_wrapper.get_samples(0,2) + tsv_samples = tsv_file_wrapper.get_samples(0, 2) + csv_samples = csv_file_wrapper.get_samples(0, 2) - tsv_samples = [ sample.decode("utf-8").split("\t") for sample in tsv_samples] - csv_samples = [ sample.decode("utf-8").split(";") for sample in csv_samples] + tsv_samples = [sample.decode("utf-8").split("\t") for sample in tsv_samples] + csv_samples = [sample.decode("utf-8").split(";") for sample in csv_samples] assert tsv_samples == csv_samples assert tsv_file_wrapper.get_label(0) == csv_file_wrapper.get_label(0) assert tsv_file_wrapper.get_label(1) == csv_file_wrapper.get_label(1) - - - - From f7665d00d4c97364fe0c4d144980c7384ba67947 Mon Sep 17 00:00:00 2001 From: francescodeaglio Date: Tue, 27 Jun 2023 12:17:18 +0200 Subject: [PATCH 11/17] Changed default separator to comma --- modyn/config/schema/modyn_config_schema.yaml | 4 ++-- modyn/storage/internal/file_wrapper/csv_file_wrapper.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/modyn/config/schema/modyn_config_schema.yaml b/modyn/config/schema/modyn_config_schema.yaml index 0e7303cc4..43c6e157f 100644 --- a/modyn/config/schema/modyn_config_schema.yaml +++ b/modyn/config/schema/modyn_config_schema.yaml @@ -115,7 +115,7 @@ properties: separator: type: string description: | - [CsvFileWrapper] The separator used in the CSV file. The default is ";" + [CsvFileWrapper] The separator used in the CSV file. The default is ",". label_index: type: number description: | @@ -129,7 +129,7 @@ properties: encoding: type: string description: | - [CsvFileWrapper] Encoding of your file. Default is utf-8 + [CsvFileWrapper] Encoding of the CSV file. Default is utf-8. ignore_last_timestamp: type: boolean description: | diff --git a/modyn/storage/internal/file_wrapper/csv_file_wrapper.py b/modyn/storage/internal/file_wrapper/csv_file_wrapper.py index 12ae49ead..28eac9961 100644 --- a/modyn/storage/internal/file_wrapper/csv_file_wrapper.py +++ b/modyn/storage/internal/file_wrapper/csv_file_wrapper.py @@ -15,7 +15,7 @@ def __init__(self, file_path: str, file_wrapper_config: dict, filesystem_wrapper if "separator" in file_wrapper_config: self.separator = file_wrapper_config["separator"] else: - self.separator = ";" + self.separator = "," if "label_index" not in file_wrapper_config: raise ValueError( From 220b93b0014a54c33e0e900751ffbdb08620b7e4 Mon Sep 17 00:00:00 2001 From: francescodeaglio Date: Tue, 27 Jun 2023 15:02:24 +0200 Subject: [PATCH 12/17] Optional validation --- modyn/config/schema/modyn_config_schema.yaml | 5 +++++ modyn/storage/internal/file_wrapper/csv_file_wrapper.py | 6 +++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/modyn/config/schema/modyn_config_schema.yaml b/modyn/config/schema/modyn_config_schema.yaml index 43c6e157f..10567657a 100644 --- a/modyn/config/schema/modyn_config_schema.yaml +++ b/modyn/config/schema/modyn_config_schema.yaml @@ -130,6 +130,11 @@ properties: type: string description: | [CsvFileWrapper] Encoding of the CSV file. Default is utf-8. + validate_file_content: + type: boolean + description: | + [CsvFileWrapper] Whether to validate the file content before inserting the data. It checks that it + is a csv, that all rows are the same size and that the 'label' column exists. Default is True ignore_last_timestamp: type: boolean description: | diff --git a/modyn/storage/internal/file_wrapper/csv_file_wrapper.py b/modyn/storage/internal/file_wrapper/csv_file_wrapper.py index 28eac9961..bb43c855e 100644 --- a/modyn/storage/internal/file_wrapper/csv_file_wrapper.py +++ b/modyn/storage/internal/file_wrapper/csv_file_wrapper.py @@ -38,7 +38,11 @@ def __init__(self, file_path: str, file_wrapper_config: dict, filesystem_wrapper # check that the file is actually a CSV self._validate_file_extension() - self._validate_file_content() + # do not validate the content only if "validate_file_content" is explicitly set to False + if ("validate_file_content" not in file_wrapper_config) or ( + "validate_file_content" in file_wrapper_config and file_wrapper_config["validate_file_content"] + ): + self._validate_file_content() def _validate_file_extension(self) -> None: """Validates the file extension as csv From 9e6afa1cb099b7f204d7caf76ba28ec470aaf0b6 Mon Sep 17 00:00:00 2001 From: francescodeaglio Date: Tue, 27 Jun 2023 15:14:00 +0200 Subject: [PATCH 13/17] Enforced label_index --- .../internal/file_wrapper/csv_file_wrapper.py | 28 +++++-------------- 1 file changed, 7 insertions(+), 21 deletions(-) diff --git a/modyn/storage/internal/file_wrapper/csv_file_wrapper.py b/modyn/storage/internal/file_wrapper/csv_file_wrapper.py index bb43c855e..aaaa94ef5 100644 --- a/modyn/storage/internal/file_wrapper/csv_file_wrapper.py +++ b/modyn/storage/internal/file_wrapper/csv_file_wrapper.py @@ -1,5 +1,5 @@ import csv -from typing import Iterator, Optional +from typing import Iterator from modyn.storage.internal.file_wrapper.abstract_file_wrapper import AbstractFileWrapper from modyn.storage.internal.file_wrapper.file_wrapper_type import FileWrapperType @@ -18,10 +18,9 @@ def __init__(self, file_path: str, file_wrapper_config: dict, filesystem_wrapper self.separator = "," if "label_index" not in file_wrapper_config: - raise ValueError( - "Please specify the index of the column that contains the label. " - "Use None if no column contains the label" - ) + raise ValueError("Please specify the index of the column that contains the label. ") + if not isinstance(file_wrapper_config["label_index"], int) or file_wrapper_config["label_index"] < 0: + raise ValueError("The label_index must be a positive integer.") self.label_index = file_wrapper_config["label_index"] # the first line might contain the header, which is useless and must not be returned. @@ -100,10 +99,7 @@ def get_samples_from_indices(self, indices: list) -> list[bytes]: return samples - def get_label(self, index: int) -> Optional[int]: - if self.label_index is None: - return None - + def get_label(self, index: int) -> int: labels = self._filter_rows_labels([index]) if len(labels) != 1: @@ -111,19 +107,9 @@ def get_label(self, index: int) -> Optional[int]: return labels[0] - def get_all_labels(self) -> list[Optional[int]]: + def get_all_labels(self) -> list[int]: reader = self._get_csv_reader() - - labels: list[Optional[int]] = [] - - if self.label_index is None: - return [None] * self.get_number_of_samples() - - for row in reader: - # labels are integer in modyn - int_label = int(row[self.label_index]) - labels.append(int_label) - + labels = [int(row[self.label_index]) for row in reader] return labels def get_number_of_samples(self) -> int: From fabe355d09cfe9a9d16cc684733f3efe20f695b8 Mon Sep 17 00:00:00 2001 From: francescodeaglio Date: Tue, 27 Jun 2023 16:06:47 +0200 Subject: [PATCH 14/17] Ouut of order indexes and tests --- .../internal/file_wrapper/csv_file_wrapper.py | 45 +++++---- .../file_wrapper/test_csv_file_wrapper.py | 95 +++++++++++++++++-- 2 files changed, 113 insertions(+), 27 deletions(-) diff --git a/modyn/storage/internal/file_wrapper/csv_file_wrapper.py b/modyn/storage/internal/file_wrapper/csv_file_wrapper.py index aaaa94ef5..355fb5918 100644 --- a/modyn/storage/internal/file_wrapper/csv_file_wrapper.py +++ b/modyn/storage/internal/file_wrapper/csv_file_wrapper.py @@ -1,5 +1,5 @@ import csv -from typing import Iterator +from typing import Iterator, Optional from modyn.storage.internal.file_wrapper.abstract_file_wrapper import AbstractFileWrapper from modyn.storage.internal.file_wrapper.file_wrapper_type import FileWrapperType @@ -68,11 +68,10 @@ def _validate_file_content(self) -> None: for row in reader: number_of_columns.append(len(row)) - if self.label_index is not None: - if not 0 <= self.label_index < len(row): - raise ValueError("Label index outside row boundary") - if not row[self.label_index].isnumeric(): # returns true iff all the characters are numbers - raise ValueError("The label must be an integer") + if not 0 <= self.label_index < len(row): + raise ValueError("Label index outside row boundary") + if not row[self.label_index].isnumeric(): # returns true iff all the characters are numbers + raise ValueError("The label must be an integer") if len(set(number_of_columns)) != 1: raise ValueError( @@ -92,12 +91,7 @@ def get_samples(self, start: int, end: int) -> list[bytes]: return self.get_samples_from_indices(indices) def get_samples_from_indices(self, indices: list) -> list[bytes]: - samples = self._filter_rows_samples(indices) - - if len(samples) != len(indices): - raise IndexError("At least one index is invalid.") - - return samples + return self._filter_rows_samples(indices) def get_label(self, index: int) -> int: labels = self._filter_rows_labels([index]) @@ -114,7 +108,6 @@ def get_all_labels(self) -> list[int]: def get_number_of_samples(self) -> int: reader = self._get_csv_reader() - return sum(1 for _ in reader) def _get_csv_reader(self) -> Iterator: @@ -149,18 +142,24 @@ def _filter_rows_samples(self, indices: list[int]) -> list[bytes]: list of byte-encoded rows """ + assert len(indices) == len(set(indices)), "An index is required more than once." reader = self._get_csv_reader() # Iterate over the rows and keep the selected ones - filtered_rows = [] + filtered_rows: list[Optional[bytes]] = [None] * len(indices) for i, row in enumerate(reader): if i in indices: # Remove the label, convert the row to bytes and append to the list row_without_label = [col for j, col in enumerate(row) if j != self.label_index] # the row is transformed in a similar csv using the same separator and then transformed to bytes - filtered_rows.append(bytes(self.separator.join(row_without_label), self.encoding)) + filtered_rows[indices.index(i)] = bytes(self.separator.join(row_without_label), self.encoding) - return filtered_rows + if sum(1 for el in filtered_rows if el is None) != 0: + raise IndexError("At least one index is invalid") + + # Here mypy complains that filtered_rows is a list of list[Optional[bytes]], + # that can't happen given the above exception + return filtered_rows # type: ignore def _filter_rows_labels(self, indices: list[int]) -> list[int]: """ @@ -172,17 +171,23 @@ def _filter_rows_labels(self, indices: list[int]) -> list[int]: list of labels """ + assert len(indices) == len(set(indices)), "An index is required more than once." reader = self._get_csv_reader() # Iterate over the rows and keep the selected ones - filtered_rows = [] + filtered_rows: list[Optional[int]] = [None] * len(indices) for i, row in enumerate(reader): if i in indices: # labels are integer in modyn int_label = int(row[self.label_index]) - filtered_rows.append(int_label) + filtered_rows[indices.index(i)] = int_label + + if sum(1 for el in filtered_rows if el is None) != 0: + raise IndexError("At least one index is invalid") - return filtered_rows + # Here mypy complains that filtered_rows is a list of list[Optional[bytes]], + # that can't happen given the above exception + return filtered_rows # type: ignore def delete_samples(self, indices: list) -> None: - return + pass diff --git a/modyn/tests/storage/internal/file_wrapper/test_csv_file_wrapper.py b/modyn/tests/storage/internal/file_wrapper/test_csv_file_wrapper.py index a0aa153da..b345b574e 100644 --- a/modyn/tests/storage/internal/file_wrapper/test_csv_file_wrapper.py +++ b/modyn/tests/storage/internal/file_wrapper/test_csv_file_wrapper.py @@ -8,7 +8,7 @@ TMP_DIR = str(pathlib.Path(os.path.abspath(__file__)).parent / "test_tmp" / "modyn") FILE_PATH = str(pathlib.Path(os.path.abspath(__file__)).parent / "test_tmp" / "modyn" / "test.csv") -WRONG_FILE_PATH = str(pathlib.Path(os.path.abspath(__file__)).parent / "test_tmp" / "modyn" / "wrong_test.csv") +CUSTOM_FILE_PATH = str(pathlib.Path(os.path.abspath(__file__)).parent / "test_tmp" / "modyn" / "wrong_test.csv") FILE_DATA = b"a;b;c;d;12\ne;f;g;h;76" INVALID_FILE_EXTENSION_PATH = str(pathlib.Path(os.path.abspath(__file__)).parent / "test_tmp" / "modyn" / "test.txt") FILE_WRAPPER_CONFIG = { @@ -136,7 +136,7 @@ def test_get_samples_from_indices_with_invalid_indices(): def write_to_file(data): - with open(WRONG_FILE_PATH, "wb") as file: + with open(CUSTOM_FILE_PATH, "wb") as file: file.write(data) @@ -146,26 +146,63 @@ def test_invalid_file_content(): write_to_file(wrong_data) with pytest.raises(ValueError): - _ = CsvFileWrapper(WRONG_FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(WRONG_FILE_PATH)) + _ = CsvFileWrapper(CUSTOM_FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(CUSTOM_FILE_PATH)) # label column outside boundary wrong_data = b"a;b;c;12\ne;f;g;76" write_to_file(wrong_data) with pytest.raises(ValueError): - _ = CsvFileWrapper(WRONG_FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(WRONG_FILE_PATH)) + _ = CsvFileWrapper(CUSTOM_FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(CUSTOM_FILE_PATH)) # str label column wrong_data = b"a;b;c;d;e;12\ne;f;g;h;h;76" write_to_file(wrong_data) with pytest.raises(ValueError): - _ = CsvFileWrapper(WRONG_FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(WRONG_FILE_PATH)) + _ = CsvFileWrapper(CUSTOM_FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(CUSTOM_FILE_PATH)) # just one str in label wrong_data = b"a;b;c;d;88;12\ne;f;g;h;h;76" write_to_file(wrong_data) with pytest.raises(ValueError): - _ = CsvFileWrapper(WRONG_FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(WRONG_FILE_PATH)) + _ = CsvFileWrapper(CUSTOM_FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(CUSTOM_FILE_PATH)) + + +def test_invalid_file_content_skip_validation(): + # extra field in one row + wrong_data = b"a;b;c;d;12;e\ne;f;g;h;76" + write_to_file(wrong_data) + + config = FILE_WRAPPER_CONFIG.copy() + config["validate_file_content"] = False + + _ = CsvFileWrapper(CUSTOM_FILE_PATH, config, MockFileSystemWrapper(CUSTOM_FILE_PATH)) + + # label column outside boundary + wrong_data = b"a;b;c;12\ne;f;g;76" + write_to_file(wrong_data) + + file_wrapper = CsvFileWrapper(CUSTOM_FILE_PATH, config, MockFileSystemWrapper(CUSTOM_FILE_PATH)) + + with pytest.raises(IndexError): # fails since index > number of columns + file_wrapper.get_label(1) + + # str label column + wrong_data = b"a;b;c;d;e;12\ne;f;g;h;h;76" + write_to_file(wrong_data) + CsvFileWrapper(CUSTOM_FILE_PATH, config, MockFileSystemWrapper(CUSTOM_FILE_PATH)) + + with pytest.raises(ValueError): # fails to convert to integer + file_wrapper.get_label(1) + + # just one str in label + wrong_data = b"a;b;c;d;88;12\ne;f;g;h;h;76" + write_to_file(wrong_data) + CsvFileWrapper(CUSTOM_FILE_PATH, config, MockFileSystemWrapper(CUSTOM_FILE_PATH)) + + file_wrapper.get_label(0) # does not fail since row 0 is ok + with pytest.raises(ValueError): # fails to convert to integer + file_wrapper.get_label(1) def test_different_separator(): @@ -178,7 +215,9 @@ def test_different_separator(): } write_to_file(tsv_file_data) - tsv_file_wrapper = CsvFileWrapper(WRONG_FILE_PATH, tsv_file_wrapper_config, MockFileSystemWrapper(WRONG_FILE_PATH)) + tsv_file_wrapper = CsvFileWrapper( + CUSTOM_FILE_PATH, tsv_file_wrapper_config, MockFileSystemWrapper(CUSTOM_FILE_PATH) + ) csv_file_wrapper = CsvFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) assert tsv_file_wrapper.get_number_of_samples() == csv_file_wrapper.get_number_of_samples() @@ -195,3 +234,45 @@ def test_different_separator(): assert tsv_file_wrapper.get_label(0) == csv_file_wrapper.get_label(0) assert tsv_file_wrapper.get_label(1) == csv_file_wrapper.get_label(1) + + +def test_out_of_order_sequence(): + content = b"A1;B1;C1;1\nA2;B2;C2;2\nA3;B3;C3;3\nA4;B4;C4;4\nA5;B5;C5;5" + converted = [b"A1;B1;C1", b"A2;B2;C2", b"A3;B3;C3", b"A4;B4;C4", b"A5;B5;C5"] + write_to_file(content) + config = { + "ignore_first_line": False, + "label_index": 3, + "separator": ";", + } + file_wrapper = CsvFileWrapper(CUSTOM_FILE_PATH, config, MockFileSystemWrapper(CUSTOM_FILE_PATH)) + + # samples + assert file_wrapper.get_samples_from_indices([2, 1]) == [converted[2], converted[1]] + assert file_wrapper.get_samples_from_indices([3, 2, 1]) == [converted[3], converted[2], converted[1]] + assert file_wrapper.get_samples_from_indices([3, 2, 4, 1]) == [ + converted[3], + converted[2], + converted[4], + converted[1], + ] + + +def test_duplicate_request(): + content = b"A1;B1;C1;1\nA2;B2;C2;2\nA3;B3;C3;3\nA4;B4;C4;4\nA5;B5;C5;5" + write_to_file(content) + config = { + "ignore_first_line": False, + "label_index": 3, + "separator": ";", + } + file_wrapper = CsvFileWrapper(CUSTOM_FILE_PATH, config, MockFileSystemWrapper(CUSTOM_FILE_PATH)) + + with pytest.raises(AssertionError): + file_wrapper.get_samples_from_indices([1, 1]) + + with pytest.raises(AssertionError): + file_wrapper.get_samples_from_indices([1, 1, 3]) + + with pytest.raises(AssertionError): + file_wrapper.get_samples_from_indices([1, 1, 13]) From 7a37bcade26d9b8ce8ed8bc3f141b1dc015de5b3 Mon Sep 17 00:00:00 2001 From: francescodeaglio Date: Tue, 27 Jun 2023 17:47:07 +0200 Subject: [PATCH 15/17] Integration test CSV format --- integrationtests/run.sh | 1 + .../storage/integrationtest_storage_csv.py | 163 ++++++++++++++++++ 2 files changed, 164 insertions(+) create mode 100644 integrationtests/storage/integrationtest_storage_csv.py diff --git a/integrationtests/run.sh b/integrationtests/run.sh index e8264f76f..6826ae1cb 100755 --- a/integrationtests/run.sh +++ b/integrationtests/run.sh @@ -10,6 +10,7 @@ python $SCRIPT_DIR/test_docker_compose.py python $SCRIPT_DIR/test_ftp_connections.py echo "Running storage integration tests" python $SCRIPT_DIR/storage/integrationtest_storage.py +python $SCRIPT_DIR/storage/integrationtest_storage_csv.py echo "Running selector integration tests" python $SCRIPT_DIR/selector/integrationtest_selector.py echo "Running model storage integration tests" diff --git a/integrationtests/storage/integrationtest_storage_csv.py b/integrationtests/storage/integrationtest_storage_csv.py new file mode 100644 index 000000000..6d5421fd7 --- /dev/null +++ b/integrationtests/storage/integrationtest_storage_csv.py @@ -0,0 +1,163 @@ +import io +import json +import os +import pathlib +import random +import time +from typing import Tuple + +# unchanged functions are imported from the original test file +from integrationtests.storage.integrationtest_storage import connect_to_storage, create_dataset_dir, \ + check_get_current_timestamp, check_dataset_availability, get_new_data_since, \ + get_data_in_interval, cleanup_dataset_dir, cleanup_storage_database +from modyn.storage.internal.grpc.generated.storage_pb2 import ( + GetRequest, RegisterNewDatasetRequest, +) +from modyn.storage.internal.grpc.generated.storage_pb2_grpc import StorageStub + +SCRIPT_PATH = pathlib.Path(os.path.realpath(__file__)) + +TIMEOUT = 120 # seconds +CONFIG_FILE = SCRIPT_PATH.parent.parent.parent / "modyn" / "config" / "examples" / "modyn_config.yaml" +# The following path leads to a directory that is mounted into the docker container and shared with the +# storage container. +DATASET_PATH = pathlib.Path("/app") / "storage" / "datasets" / "test_dataset" + +# Because we have no mapping of file to key (happens in the storage service), we have to keep +# track of the samples we added to the dataset ourselves and compare them to the samples we get +# from the storage service. +FIRST_ADDED_CSVS = [] +SECOND_ADDED_CSVS = [] +CSV_UPDATED_TIME_STAMPS = [] + +def register_new_dataset() -> None: + storage_channel = connect_to_storage() + + storage = StorageStub(storage_channel) + + request = RegisterNewDatasetRequest( + base_path=str(DATASET_PATH), + dataset_id="test_dataset", + description="Test dataset for integration tests of CSV wrapper.", + file_wrapper_config=json.dumps({"file_extension": ".csv", "separator": ",", "label_index": 3}), + file_wrapper_type="CsvFileWrapper", + filesystem_wrapper_type="LocalFilesystemWrapper", + version="0.1.0", + ) + + response = storage.RegisterNewDataset(request) + + assert response.success, "Could not register new dataset." + +def add_file_to_dataset(csv_file_content: str, name: str) -> None: + with open(DATASET_PATH / name, "w") as f: + f.write(csv_file_content) + CSV_UPDATED_TIME_STAMPS.append(int(round(os.path.getmtime(DATASET_PATH / name) * 1000))) + + +def create_random_csv_row(file: int, counter: int) -> str: + index = random.randint(1, 1000) + return f"A{index}file{file},B{index}file{file},C{index}file{file},{counter}" + +def create_random_csv_file(file: int, counter: int) -> Tuple[str, list[str], int]: + rows = [] + samples = [] + for repeat in range(25): + row = create_random_csv_row(file, counter) + counter += 1 + rows.append(row) + sample = ",".join(row.split(",")[:3]) #remove the label + samples.append(sample) + + return "\n".join(rows), samples, counter + +def add_files_to_dataset(start_number: int, end_number: int, files_added: list[bytes], rows_added: list[bytes]) -> None: + create_dataset_dir() + counter = 0 + for i in range(start_number, end_number): + csv_file, samples_csv_file, counter = create_random_csv_file(i, counter) + add_file_to_dataset(csv_file, f"csv_{i}.csv") + files_added.append(bytes(csv_file, "utf-8")) + [rows_added.append(bytes(row, "utf-8")) for row in samples_csv_file] + + + +def check_data(keys: list[str], expected_samples: list[bytes]) -> None: + storage_channel = connect_to_storage() + + storage = StorageStub(storage_channel) + + request = GetRequest( + dataset_id="test_dataset", + keys=keys, + ) + samples_counter = 0 + for _ , response in enumerate(storage.Get(request)): + if len(response.samples) == 0: + assert False, f"Could not get sample with key {keys[samples_counter]}." + for sample in response.samples: + if sample is None: + assert False, f"Could not get sample with key {keys[samples_counter]}." + if sample not in expected_samples: + raise ValueError(f"Sample {sample} with key {keys[samples_counter]} is not present in the expected samples {expected_samples}. ") + samples_counter += 1 + assert samples_counter == len(keys), f"Could not get all samples. Samples missing: keys: {sorted(keys)} i: {samples_counter}" + + +def test_storage() -> None: + check_get_current_timestamp() # Check if the storage service is available. + create_dataset_dir() + add_files_to_dataset(0, 10, [], FIRST_ADDED_CSVS) # Add samples to the dataset. + register_new_dataset() + check_dataset_availability() # Check if the dataset is available. + + response = None + for i in range(500): + responses = list(get_new_data_since(0)) + assert len(responses) < 2, f"Received batched response, shouldn't happen: {responses}" + if len(responses) == 1: + response = responses[0] + if len(response.keys) == 250: #10 files, each one with 250 samples + break + time.sleep(1) + + assert response is not None, "Did not get any response from Storage" + assert len(response.keys) == 250, f"Not all samples were returned. Samples returned: {response.keys}" + + check_data(response.keys, FIRST_ADDED_CSVS) + + add_files_to_dataset(10, 20, [], SECOND_ADDED_CSVS) # Add more samples to the dataset. + + for i in range(500): + responses = list(get_new_data_since(CSV_UPDATED_TIME_STAMPS[9] + 1)) + assert len(responses) < 2, f"Received batched response, shouldn't happen: {responses}" + if len(responses) == 1: + response = responses[0] + if len(response.keys) == 250: + break + time.sleep(1) + + assert response is not None, "Did not get any response from Storage" + assert len(response.keys) == 250, f"Not all samples were returned. Samples returned: {response.keys}" + + check_data(response.keys, SECOND_ADDED_CSVS) + + responses = list(get_data_in_interval(0, CSV_UPDATED_TIME_STAMPS[9])) + assert len(responses) == 1, f"Received batched/no response, shouldn't happen: {responses}" + response = responses[0] + + check_data(response.keys, FIRST_ADDED_CSVS) + + check_get_current_timestamp() # Check if the storage service is still available. + + +def main() -> None: + try: + test_storage() + finally: + cleanup_dataset_dir() + cleanup_storage_database() + + +if __name__ == "__main__": + main() From 0e0dd5375bdf0577f6c36700611e6cbd1f8943c7 Mon Sep 17 00:00:00 2001 From: francescodeaglio Date: Wed, 28 Jun 2023 09:05:31 +0200 Subject: [PATCH 16/17] CSV reader integration tests --- .../storage/integrationtest_storage_csv.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/integrationtests/storage/integrationtest_storage_csv.py b/integrationtests/storage/integrationtest_storage_csv.py index 6d5421fd7..ee1c3d854 100644 --- a/integrationtests/storage/integrationtest_storage_csv.py +++ b/integrationtests/storage/integrationtest_storage_csv.py @@ -1,7 +1,12 @@ -import io +############ +# storage integration tests adapted to CSV input format. +# Unchanged functions are imported from the original test +# Instead of images, we have CSV files. Each file has 25 rows end each row has 5 columns. +# f"A{index}file{file},B{index}file{file},C{index}file{file},{counter}" +# where index is a random number, file is the fileindex and the label (last column) is a global counter + import json import os -import pathlib import random import time from typing import Tuple @@ -9,19 +14,12 @@ # unchanged functions are imported from the original test file from integrationtests.storage.integrationtest_storage import connect_to_storage, create_dataset_dir, \ check_get_current_timestamp, check_dataset_availability, get_new_data_since, \ - get_data_in_interval, cleanup_dataset_dir, cleanup_storage_database + get_data_in_interval, cleanup_dataset_dir, cleanup_storage_database, DATASET_PATH from modyn.storage.internal.grpc.generated.storage_pb2 import ( GetRequest, RegisterNewDatasetRequest, ) from modyn.storage.internal.grpc.generated.storage_pb2_grpc import StorageStub -SCRIPT_PATH = pathlib.Path(os.path.realpath(__file__)) - -TIMEOUT = 120 # seconds -CONFIG_FILE = SCRIPT_PATH.parent.parent.parent / "modyn" / "config" / "examples" / "modyn_config.yaml" -# The following path leads to a directory that is mounted into the docker container and shared with the -# storage container. -DATASET_PATH = pathlib.Path("/app") / "storage" / "datasets" / "test_dataset" # Because we have no mapping of file to key (happens in the storage service), we have to keep # track of the samples we added to the dataset ourselves and compare them to the samples we get @@ -54,7 +52,6 @@ def add_file_to_dataset(csv_file_content: str, name: str) -> None: f.write(csv_file_content) CSV_UPDATED_TIME_STAMPS.append(int(round(os.path.getmtime(DATASET_PATH / name) * 1000))) - def create_random_csv_row(file: int, counter: int) -> str: index = random.randint(1, 1000) return f"A{index}file{file},B{index}file{file},C{index}file{file},{counter}" From ddbcc6b61b53f9ada0182ffa122170b4445f3aa6 Mon Sep 17 00:00:00 2001 From: francescodeaglio Date: Wed, 28 Jun 2023 11:23:44 +0200 Subject: [PATCH 17/17] Compliance check on integration tests --- .../storage/integrationtest_storage_csv.py | 38 +++++++++++++------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/integrationtests/storage/integrationtest_storage_csv.py b/integrationtests/storage/integrationtest_storage_csv.py index ee1c3d854..0cdf6679f 100644 --- a/integrationtests/storage/integrationtest_storage_csv.py +++ b/integrationtests/storage/integrationtest_storage_csv.py @@ -12,15 +12,20 @@ from typing import Tuple # unchanged functions are imported from the original test file -from integrationtests.storage.integrationtest_storage import connect_to_storage, create_dataset_dir, \ - check_get_current_timestamp, check_dataset_availability, get_new_data_since, \ - get_data_in_interval, cleanup_dataset_dir, cleanup_storage_database, DATASET_PATH -from modyn.storage.internal.grpc.generated.storage_pb2 import ( - GetRequest, RegisterNewDatasetRequest, +from integrationtests.storage.integrationtest_storage import ( + DATASET_PATH, + check_dataset_availability, + check_get_current_timestamp, + cleanup_dataset_dir, + cleanup_storage_database, + connect_to_storage, + create_dataset_dir, + get_data_in_interval, + get_new_data_since, ) +from modyn.storage.internal.grpc.generated.storage_pb2 import GetRequest, RegisterNewDatasetRequest from modyn.storage.internal.grpc.generated.storage_pb2_grpc import StorageStub - # Because we have no mapping of file to key (happens in the storage service), we have to keep # track of the samples we added to the dataset ourselves and compare them to the samples we get # from the storage service. @@ -28,6 +33,7 @@ SECOND_ADDED_CSVS = [] CSV_UPDATED_TIME_STAMPS = [] + def register_new_dataset() -> None: storage_channel = connect_to_storage() @@ -47,15 +53,18 @@ def register_new_dataset() -> None: assert response.success, "Could not register new dataset." + def add_file_to_dataset(csv_file_content: str, name: str) -> None: with open(DATASET_PATH / name, "w") as f: f.write(csv_file_content) CSV_UPDATED_TIME_STAMPS.append(int(round(os.path.getmtime(DATASET_PATH / name) * 1000))) + def create_random_csv_row(file: int, counter: int) -> str: index = random.randint(1, 1000) return f"A{index}file{file},B{index}file{file},C{index}file{file},{counter}" + def create_random_csv_file(file: int, counter: int) -> Tuple[str, list[str], int]: rows = [] samples = [] @@ -63,11 +72,12 @@ def create_random_csv_file(file: int, counter: int) -> Tuple[str, list[str], int row = create_random_csv_row(file, counter) counter += 1 rows.append(row) - sample = ",".join(row.split(",")[:3]) #remove the label + sample = ",".join(row.split(",")[:3]) # remove the label samples.append(sample) return "\n".join(rows), samples, counter + def add_files_to_dataset(start_number: int, end_number: int, files_added: list[bytes], rows_added: list[bytes]) -> None: create_dataset_dir() counter = 0 @@ -78,7 +88,6 @@ def add_files_to_dataset(start_number: int, end_number: int, files_added: list[b [rows_added.append(bytes(row, "utf-8")) for row in samples_csv_file] - def check_data(keys: list[str], expected_samples: list[bytes]) -> None: storage_channel = connect_to_storage() @@ -89,16 +98,21 @@ def check_data(keys: list[str], expected_samples: list[bytes]) -> None: keys=keys, ) samples_counter = 0 - for _ , response in enumerate(storage.Get(request)): + for _, response in enumerate(storage.Get(request)): if len(response.samples) == 0: assert False, f"Could not get sample with key {keys[samples_counter]}." for sample in response.samples: if sample is None: assert False, f"Could not get sample with key {keys[samples_counter]}." if sample not in expected_samples: - raise ValueError(f"Sample {sample} with key {keys[samples_counter]} is not present in the expected samples {expected_samples}. ") + raise ValueError( + f"Sample {sample} with key {keys[samples_counter]} is not present in the " + f"expected samples {expected_samples}. " + ) samples_counter += 1 - assert samples_counter == len(keys), f"Could not get all samples. Samples missing: keys: {sorted(keys)} i: {samples_counter}" + assert samples_counter == len( + keys + ), f"Could not get all samples. Samples missing: keys: {sorted(keys)} i: {samples_counter}" def test_storage() -> None: @@ -114,7 +128,7 @@ def test_storage() -> None: assert len(responses) < 2, f"Received batched response, shouldn't happen: {responses}" if len(responses) == 1: response = responses[0] - if len(response.keys) == 250: #10 files, each one with 250 samples + if len(response.keys) == 250: # 10 files, each one with 250 samples break time.sleep(1)