diff --git a/integrationtests/run.sh b/integrationtests/run.sh
index e8264f76f..6826ae1cb 100755
--- a/integrationtests/run.sh
+++ b/integrationtests/run.sh
@@ -10,6 +10,7 @@ python $SCRIPT_DIR/test_docker_compose.py
 python $SCRIPT_DIR/test_ftp_connections.py
 echo "Running storage integration tests"
 python $SCRIPT_DIR/storage/integrationtest_storage.py
+python $SCRIPT_DIR/storage/integrationtest_storage_csv.py
 echo "Running selector integration tests"
 python $SCRIPT_DIR/selector/integrationtest_selector.py
 echo "Running model storage integration tests"
diff --git a/integrationtests/storage/integrationtest_storage_csv.py b/integrationtests/storage/integrationtest_storage_csv.py
new file mode 100644
index 000000000..0cdf6679f
--- /dev/null
+++ b/integrationtests/storage/integrationtest_storage_csv.py
@@ -0,0 +1,174 @@
+############
+# Storage integration tests adapted to the CSV input format.
+# Unchanged functions are imported from the original test file.
+# Instead of images, we have CSV files. Each file has 25 rows and each row has 4 columns:
+#     f"A{index}file{file},B{index}file{file},C{index}file{file},{counter}"
+# where index is a random number, file is the file index, and the label (last column) is a global counter.
+
+import json
+import os
+import random
+import time
+from typing import Tuple
+
+# unchanged functions are imported from the original test file
+from integrationtests.storage.integrationtest_storage import (
+    DATASET_PATH,
+    check_dataset_availability,
+    check_get_current_timestamp,
+    cleanup_dataset_dir,
+    cleanup_storage_database,
+    connect_to_storage,
+    create_dataset_dir,
+    get_data_in_interval,
+    get_new_data_since,
+)
+from modyn.storage.internal.grpc.generated.storage_pb2 import GetRequest, RegisterNewDatasetRequest
+from modyn.storage.internal.grpc.generated.storage_pb2_grpc import StorageStub
+
+# Because we have no mapping of file to key (this happens in the storage service), we have to keep
+# track of the samples we added to the dataset ourselves and compare them to the samples we get
+# from the storage service.
+FIRST_ADDED_CSVS = []
+SECOND_ADDED_CSVS = []
+CSV_UPDATED_TIME_STAMPS = []
+
+
+def register_new_dataset() -> None:
+    storage_channel = connect_to_storage()
+
+    storage = StorageStub(storage_channel)
+
+    request = RegisterNewDatasetRequest(
+        base_path=str(DATASET_PATH),
+        dataset_id="test_dataset",
+        description="Test dataset for integration tests of the CSV wrapper.",
+        file_wrapper_config=json.dumps({"file_extension": ".csv", "separator": ",", "label_index": 3}),
+        file_wrapper_type="CsvFileWrapper",
+        filesystem_wrapper_type="LocalFilesystemWrapper",
+        version="0.1.0",
+    )
+
+    response = storage.RegisterNewDataset(request)
+
+    assert response.success, "Could not register new dataset."
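For reference, a minimal standalone sketch (not part of the PR) of the row layout the helpers below generate, and of how the label column is split off under the `label_index: 3` config used in `register_new_dataset` above; the concrete values are made up:

```python
# Illustrative only: parse two rows in the test's format and separate the label.
import csv
import io

csv_content = "A7file0,B7file0,C7file0,0\nA3file0,B3file0,C3file0,1"

for row in csv.reader(io.StringIO(csv_content), delimiter=","):
    sample = ",".join(row[:3])  # the three feature columns, label removed
    label = int(row[3])         # the label column must parse as an integer
    print(sample, label)
```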
+
+
+def add_file_to_dataset(csv_file_content: str, name: str) -> None:
+    with open(DATASET_PATH / name, "w") as f:
+        f.write(csv_file_content)
+    CSV_UPDATED_TIME_STAMPS.append(int(round(os.path.getmtime(DATASET_PATH / name) * 1000)))
+
+
+def create_random_csv_row(file: int, counter: int) -> str:
+    index = random.randint(1, 1000)
+    return f"A{index}file{file},B{index}file{file},C{index}file{file},{counter}"
+
+
+def create_random_csv_file(file: int, counter: int) -> Tuple[str, list[str], int]:
+    rows = []
+    samples = []
+    for _ in range(25):
+        row = create_random_csv_row(file, counter)
+        counter += 1
+        rows.append(row)
+        sample = ",".join(row.split(",")[:3])  # remove the label
+        samples.append(sample)
+
+    return "\n".join(rows), samples, counter
+
+
+def add_files_to_dataset(start_number: int, end_number: int, files_added: list[bytes], rows_added: list[bytes]) -> None:
+    create_dataset_dir()
+    counter = 0
+    for i in range(start_number, end_number):
+        csv_file, samples_csv_file, counter = create_random_csv_file(i, counter)
+        add_file_to_dataset(csv_file, f"csv_{i}.csv")
+        files_added.append(bytes(csv_file, "utf-8"))
+        for row in samples_csv_file:
+            rows_added.append(bytes(row, "utf-8"))
+
+
+def check_data(keys: list[str], expected_samples: list[bytes]) -> None:
+    storage_channel = connect_to_storage()
+
+    storage = StorageStub(storage_channel)
+
+    request = GetRequest(
+        dataset_id="test_dataset",
+        keys=keys,
+    )
+    samples_counter = 0
+    for response in storage.Get(request):
+        assert len(response.samples) > 0, f"Could not get sample with key {keys[samples_counter]}."
+        for sample in response.samples:
+            assert sample is not None, f"Could not get sample with key {keys[samples_counter]}."
+            if sample not in expected_samples:
+                raise ValueError(
+                    f"Sample {sample} with key {keys[samples_counter]} is not present in the "
+                    f"expected samples {expected_samples}."
+                )
+            samples_counter += 1
+    assert samples_counter == len(
+        keys
+    ), f"Could not get all samples. keys: {sorted(keys)}, samples received: {samples_counter}"
+
+
+def test_storage() -> None:
+    check_get_current_timestamp()  # Check if the storage service is available.
+    create_dataset_dir()
+    add_files_to_dataset(0, 10, [], FIRST_ADDED_CSVS)  # Add samples to the dataset.
+    register_new_dataset()
+    check_dataset_availability()  # Check if the dataset is available.
+
+    response = None
+    for _ in range(500):
+        responses = list(get_new_data_since(0))
+        assert len(responses) < 2, f"Received batched response, shouldn't happen: {responses}"
+        if len(responses) == 1:
+            response = responses[0]
+            if len(response.keys) == 250:  # 10 files with 25 rows each = 250 samples
+                break
+        time.sleep(1)
+
+    assert response is not None, "Did not get any response from Storage"
+    assert len(response.keys) == 250, f"Not all samples were returned. Samples returned: {response.keys}"
+
+    check_data(response.keys, FIRST_ADDED_CSVS)
+
+    add_files_to_dataset(10, 20, [], SECOND_ADDED_CSVS)  # Add more samples to the dataset.
+
+    response = None
+    for _ in range(500):
+        responses = list(get_new_data_since(CSV_UPDATED_TIME_STAMPS[9] + 1))
+        assert len(responses) < 2, f"Received batched response, shouldn't happen: {responses}"
+        if len(responses) == 1:
+            response = responses[0]
+            if len(response.keys) == 250:
+                break
+        time.sleep(1)
+
+    assert response is not None, "Did not get any response from Storage"
+    assert len(response.keys) == 250, f"Not all samples were returned. Samples returned: {response.keys}"
+
+    check_data(response.keys, SECOND_ADDED_CSVS)
+
+    responses = list(get_data_in_interval(0, CSV_UPDATED_TIME_STAMPS[9]))
+    assert len(responses) == 1, f"Received batched/no response, shouldn't happen: {responses}"
+    response = responses[0]
+
+    check_data(response.keys, FIRST_ADDED_CSVS)
+
+    check_get_current_timestamp()  # Check if the storage service is still available.
+
+
+def main() -> None:
+    try:
+        test_storage()
+    finally:
+        cleanup_dataset_dir()
+        cleanup_storage_database()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/modyn/config/schema/modyn_config_schema.yaml b/modyn/config/schema/modyn_config_schema.yaml
index 5cb8b90b8..289518c0d 100644
--- a/modyn/config/schema/modyn_config_schema.yaml
+++ b/modyn/config/schema/modyn_config_schema.yaml
@@ -37,7 +37,8 @@ properties:
       sample_batch_size:
         type: number
        description: |
-          The size of a batch when requesting new samples from storage. All new samples are returned, however, to reduce the size of a single answer the keys are batched in sizes of `sample_batch_size`.
+          The size of a batch when requesting new samples from storage. All new samples are returned; however, to reduce
+          the size of a single answer, the keys are batched in sizes of `sample_batch_size`.
       sample_dbinsertion_batchsize:
         type: number
         description: |
@@ -49,7 +50,9 @@ properties:
       sample_table_unlogged:
         type: boolean
         description: |
-          This configures whether the table storing all samples is UNLOGGED (= high performance) or crash resilient. Defaults to True. For datasets with many samples (such as Criteo), this is recommended for highest insertion performance. In other scenarios, this might not be necessary.
+          This configures whether the table storing all samples is UNLOGGED (= high performance) or crash-resilient.
+          Defaults to True. For datasets with many samples (such as Criteo), this is recommended for the highest
+          insertion performance. In other scenarios, this might not be necessary.
       force_fallback_insert:
         type: boolean
         description: |
@@ -99,15 +102,39 @@ properties:
       record_size:
         type: number
         description: |
-          The size of each full record in bytes (label + features) for a binary file wrapper.
+          [BinaryFileWrapper] The size of each full record in bytes (label + features).
       label_size:
         type: number
         description: |
-          The size of the label field in bytes for a binary file wrapper.
+          [BinaryFileWrapper] The size of the label field in bytes.
       byteorder:
         type: string
         description: |
-          The byteorder when reading an integer from multibyte data in a binary file. Should either be "big" or "little".
+          [BinaryFileWrapper] The byteorder when reading an integer from multibyte data in a binary file.
+          Should be either "big" or "little".
+      separator:
+        type: string
+        description: |
+          [CsvFileWrapper] The separator used in the CSV file. The default is ",".
+      label_index:
+        type: number
+        description: |
+          [CsvFileWrapper] Column index of the label.
+          For example, if the columns are "width", "height", "age", "label", you should set label_index to 3.
+      ignore_first_line:
+        type: boolean
+        description: |
+          [CsvFileWrapper] If the first line is the table header, you can skip it by setting this parameter to True.
+          Default is False.
+      encoding:
+        type: string
+        description: |
+          [CsvFileWrapper] Encoding of the CSV file. Default is utf-8.
+      validate_file_content:
+        type: boolean
+        description: |
+          [CsvFileWrapper] Whether to validate the file content before inserting the data. It checks that the file
+          is a valid CSV, that all rows have the same number of columns, and that the label column exists.
+          Default is True.
       ignore_last_timestamp:
         type: boolean
         description: |
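As a usage illustration for the new schema entries, a hypothetical `file_wrapper_config` for a semicolon-separated CSV file with a header row might look as follows; all values here are illustrative, not defaults:

```yaml
# Illustrative only: CSV options as they could appear in a dataset's file_wrapper_config.
file_wrapper_config:
  file_extension: ".csv"
  separator: ";"
  label_index: 4
  ignore_first_line: true   # the first row is a header
  encoding: "utf-8"
  validate_file_content: true
```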
diff --git a/modyn/storage/internal/file_wrapper/csv_file_wrapper.py b/modyn/storage/internal/file_wrapper/csv_file_wrapper.py
new file mode 100644
index 000000000..355fb5918
--- /dev/null
+++ b/modyn/storage/internal/file_wrapper/csv_file_wrapper.py
@@ -0,0 +1,193 @@
+import csv
+from typing import Iterator, Optional
+
+from modyn.storage.internal.file_wrapper.abstract_file_wrapper import AbstractFileWrapper
+from modyn.storage.internal.file_wrapper.file_wrapper_type import FileWrapperType
+from modyn.storage.internal.filesystem_wrapper.abstract_filesystem_wrapper import AbstractFileSystemWrapper
+
+
+class CsvFileWrapper(AbstractFileWrapper):
+    def __init__(self, file_path: str, file_wrapper_config: dict, filesystem_wrapper: AbstractFileSystemWrapper):
+        super().__init__(file_path, file_wrapper_config, filesystem_wrapper)
+
+        self.file_wrapper_type = FileWrapperType.CsvFileWrapper
+
+        self.separator = file_wrapper_config.get("separator", ",")
+
+        if "label_index" not in file_wrapper_config:
+            raise ValueError("Please specify the index of the column that contains the label.")
+        if not isinstance(file_wrapper_config["label_index"], int) or file_wrapper_config["label_index"] < 0:
+            raise ValueError("The label_index must be a non-negative integer.")
+        self.label_index = file_wrapper_config["label_index"]
+
+        # the first line might contain the header, which must not be returned as a sample
+        self.ignore_first_line = file_wrapper_config.get("ignore_first_line", False)
+
+        self.encoding = file_wrapper_config.get("encoding", "utf-8")
+
+        # check that the file is actually a CSV
+        self._validate_file_extension()
+
+        # validate the content unless "validate_file_content" is explicitly set to False
+        if file_wrapper_config.get("validate_file_content", True):
+            self._validate_file_content()
+
+    def _validate_file_extension(self) -> None:
+        """Validates that the file has a .csv extension.
+
+        Raises:
+            ValueError: File has wrong file extension
+        """
+        if not self.file_path.endswith(".csv"):
+            raise ValueError("File has wrong file extension.")
+
+    def _validate_file_content(self) -> None:
+        """
+        Performs the following checks:
+        - the specified label column is castable to integer
+        - each row contains the label_index column
+        - each row has the same number of columns
+
+        Raises a ValueError if a condition is not met.
+        """
+        reader = self._get_csv_reader()
+
+        number_of_columns = []
+
+        for row in reader:
+            number_of_columns.append(len(row))
+            if not 0 <= self.label_index < len(row):
+                raise ValueError("Label index outside row boundary")
+            if not row[self.label_index].isnumeric():  # true iff all the characters are digits
+                raise ValueError("The label must be an integer")
+
+        if len(set(number_of_columns)) != 1:
+            raise ValueError(
+                "Some rows have a different number of columns. "
+                f"Columns per row: {number_of_columns}"
+            )
+
+    def get_sample(self, index: int) -> bytes:
+        samples = self._filter_rows_samples([index])
+
+        if len(samples) != 1:
+            raise IndexError("Invalid index")
+
+        return samples[0]
+
+    def get_samples(self, start: int, end: int) -> list[bytes]:
+        indices = list(range(start, end))
+        return self.get_samples_from_indices(indices)
+
+    def get_samples_from_indices(self, indices: list) -> list[bytes]:
+        return self._filter_rows_samples(indices)
+
+    def get_label(self, index: int) -> int:
+        labels = self._filter_rows_labels([index])
+
+        if len(labels) != 1:
+            raise IndexError("Invalid index.")
+
+        return labels[0]
+
+    def get_all_labels(self) -> list[int]:
+        reader = self._get_csv_reader()
+        labels = [int(row[self.label_index]) for row in reader]
+        return labels
+
+    def get_number_of_samples(self) -> int:
+        reader = self._get_csv_reader()
+        return sum(1 for _ in reader)
+
+    def _get_csv_reader(self) -> Iterator:
+        """
+        Fetches the bytes from the filesystem_wrapper and creates a csv.reader from them.
+
+        Returns:
+            csv.reader
+        """
+        data_file = self.filesystem_wrapper.get(self.file_path)
+
+        # Convert the bytes content to a string
+        data_file_str = data_file.decode(self.encoding)
+
+        lines = data_file_str.split("\n")
+
+        # Create a CSV reader
+        reader = csv.reader(lines, delimiter=self.separator)
+
+        # skip the header if required
+        if self.ignore_first_line:
+            next(reader)
+
+        return reader
+
+    def _filter_rows_samples(self, indices: list[int]) -> list[bytes]:
+        """
+        Filters the selected rows and removes the label column.
+
+        Args:
+            indices: list of rows that must be kept
+
+        Returns:
+            list of byte-encoded rows
+        """
+        assert len(indices) == len(set(indices)), "Duplicate indices are not allowed."
+        reader = self._get_csv_reader()
+
+        # Iterate over the rows and keep only the selected ones, in the requested order
+        filtered_rows: list[Optional[bytes]] = [None] * len(indices)
+        for i, row in enumerate(reader):
+            if i in indices:
+                # Remove the label, rejoin the remaining columns with the same separator, and encode to bytes
+                row_without_label = [col for j, col in enumerate(row) if j != self.label_index]
+                filtered_rows[indices.index(i)] = bytes(self.separator.join(row_without_label), self.encoding)
+
+        if any(row is None for row in filtered_rows):
+            raise IndexError("At least one index is invalid")
+
+        # Here mypy complains that filtered_rows is list[Optional[bytes]] instead of list[bytes];
+        # None entries cannot occur given the exception above
+        return filtered_rows  # type: ignore
+
+    def _filter_rows_labels(self, indices: list[int]) -> list[int]:
+        """
+        Filters the selected rows and extracts the label column.
+
+        Args:
+            indices: list of rows that must be kept
+
+        Returns:
+            list of labels
+        """
+        assert len(indices) == len(set(indices)), "Duplicate indices are not allowed."
+        reader = self._get_csv_reader()
+
+        # Iterate over the rows and keep only the selected ones, in the requested order
+        filtered_rows: list[Optional[int]] = [None] * len(indices)
+        for i, row in enumerate(reader):
+            if i in indices:
+                # labels are integers in Modyn
+                filtered_rows[indices.index(i)] = int(row[self.label_index])
+
+        if any(label is None for label in filtered_rows):
+            raise IndexError("At least one index is invalid")
+
+        # Here mypy complains that filtered_rows is list[Optional[int]] instead of list[int];
+        # None entries cannot occur given the exception above
+        return filtered_rows  # type: ignore
+
+    def delete_samples(self, indices: list) -> None:
+        # Deleting individual samples is not supported for CSV files; this is intentionally a no-op.
+        pass
diff --git a/modyn/storage/internal/file_wrapper/file_wrapper_type.py b/modyn/storage/internal/file_wrapper/file_wrapper_type.py
index 4a9f1c32b..758b99cbe 100644
--- a/modyn/storage/internal/file_wrapper/file_wrapper_type.py
+++ b/modyn/storage/internal/file_wrapper/file_wrapper_type.py
@@ -12,6 +12,7 @@ class FileWrapperType(Enum):
 
     SingleSampleFileWrapper = "single_sample_file_wrapper"  # pylint: disable=invalid-name
     BinaryFileWrapper = "binary_file_wrapper"  # pylint: disable=invalid-name
+    CsvFileWrapper = "csv_file_wrapper"  # pylint: disable=invalid-name
 
 
 class InvalidFileWrapperTypeException(Exception):
diff --git a/modyn/tests/storage/internal/file_wrapper/test_csv_file_wrapper.py b/modyn/tests/storage/internal/file_wrapper/test_csv_file_wrapper.py
new file mode 100644
index 000000000..b345b574e
--- /dev/null
+++ b/modyn/tests/storage/internal/file_wrapper/test_csv_file_wrapper.py
@@ -0,0 +1,278 @@
+import os
+import pathlib
+import shutil
+
+import pytest
+from modyn.storage.internal.file_wrapper.csv_file_wrapper import CsvFileWrapper
+from modyn.storage.internal.file_wrapper.file_wrapper_type import FileWrapperType
+
+TMP_DIR = str(pathlib.Path(os.path.abspath(__file__)).parent / "test_tmp" / "modyn")
+FILE_PATH = str(pathlib.Path(os.path.abspath(__file__)).parent / "test_tmp" / "modyn" / "test.csv")
+CUSTOM_FILE_PATH = str(pathlib.Path(os.path.abspath(__file__)).parent / "test_tmp" / "modyn" / "wrong_test.csv")
+FILE_DATA = b"a;b;c;d;12\ne;f;g;h;76"
+INVALID_FILE_EXTENSION_PATH = str(pathlib.Path(os.path.abspath(__file__)).parent / "test_tmp" / "modyn" / "test.txt")
+FILE_WRAPPER_CONFIG = {
+    "ignore_first_line": False,
+    "label_index": 4,
+    "separator": ";",
+}
+
+
+def setup():
+    os.makedirs(TMP_DIR, exist_ok=True)
+
+    with open(FILE_PATH, "wb") as file:
+        file.write(FILE_DATA)
+
+
+def teardown():
+    os.remove(FILE_PATH)
+    shutil.rmtree(TMP_DIR)
+
+
+class MockFileSystemWrapper:
+    def __init__(self, file_path):
+        self.file_path = file_path
+
+    def get(self, file_path):
+        with open(file_path, "rb") as file:
+            return file.read()
+
+    def get_size(self, path):
+        return os.path.getsize(path)
+
+
+def test_init():
+    file_wrapper = CsvFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH))
+    assert file_wrapper.file_path == FILE_PATH
+    assert file_wrapper.file_wrapper_type == FileWrapperType.CsvFileWrapper
+    assert file_wrapper.encoding == "utf-8"
+    assert file_wrapper.label_index == 4
+    assert not file_wrapper.ignore_first_line
+    assert file_wrapper.separator == ";"
+
+
+def test_init_with_invalid_file_extension():
+    with pytest.raises(ValueError):
+        CsvFileWrapper(
+            INVALID_FILE_EXTENSION_PATH,
+            FILE_WRAPPER_CONFIG,
+            MockFileSystemWrapper(INVALID_FILE_EXTENSION_PATH),
+        )
+
+
+def test_get_number_of_samples():
+    file_wrapper = CsvFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH))
+    assert file_wrapper.get_number_of_samples() == 2
+
+    # check that the first line is correctly ignored
+    file_wrapper.ignore_first_line = True
+    assert file_wrapper.get_number_of_samples() == 1
+
+
+def test_get_sample():
+    file_wrapper = CsvFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH))
+    sample = file_wrapper.get_sample(0)
+    assert sample == b"a;b;c;d"
+
+    sample = file_wrapper.get_sample(1)
+    assert sample == b"e;f;g;h"
+
+
+def test_get_sample_with_invalid_index():
+    file_wrapper = CsvFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH))
+    with pytest.raises(IndexError):
+        file_wrapper.get_sample(10)
+
+
+def test_get_label():
+    file_wrapper = CsvFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH))
+    label = file_wrapper.get_label(0)
+    assert label == 12
+
+    label = file_wrapper.get_label(1)
+    assert label == 76
+
+    with pytest.raises(IndexError):
+        file_wrapper.get_label(2)
+
+
+def test_get_all_labels():
+    file_wrapper = CsvFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH))
+    assert file_wrapper.get_all_labels() == [12, 76]
+
+
+def test_get_label_with_invalid_index():
+    file_wrapper = CsvFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH))
+    with pytest.raises(IndexError):
+        file_wrapper.get_label(10)
+
+
+def test_get_samples():
+    file_wrapper = CsvFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH))
+    samples = file_wrapper.get_samples(0, 1)
+    assert len(samples) == 1
+    assert samples[0] == b"a;b;c;d"
+
+    samples = file_wrapper.get_samples(0, 2)
+    assert len(samples) == 2
+    assert samples[0] == b"a;b;c;d"
+    assert samples[1] == b"e;f;g;h"
+
+
+def test_get_samples_with_invalid_index():
+    file_wrapper = CsvFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH))
+    with pytest.raises(IndexError):
+        file_wrapper.get_samples(0, 5)
+
+    with pytest.raises(IndexError):
+        file_wrapper.get_samples(3, 4)
+
+
+def test_get_samples_from_indices_with_invalid_indices():
+    file_wrapper = CsvFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH))
+    with pytest.raises(IndexError):
+        file_wrapper.get_samples_from_indices([-2, 1])
+
+
+def write_to_file(data):
+    with open(CUSTOM_FILE_PATH, "wb") as file:
+        file.write(data)
+
+
+def test_invalid_file_content():
+    # extra field in one row
+    wrong_data = b"a;b;c;d;12;e\ne;f;g;h;76"
+    write_to_file(wrong_data)
+
+    with pytest.raises(ValueError):
+        _ = CsvFileWrapper(CUSTOM_FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(CUSTOM_FILE_PATH))
+
+    # label column outside row boundary
+    wrong_data = b"a;b;c;12\ne;f;g;76"
+    write_to_file(wrong_data)
+
+    with pytest.raises(ValueError):
+        _ = CsvFileWrapper(CUSTOM_FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(CUSTOM_FILE_PATH))
+
+    # string label column
+    wrong_data = b"a;b;c;d;e;12\ne;f;g;h;h;76"
+    write_to_file(wrong_data)
+    with pytest.raises(ValueError):
+        _ = CsvFileWrapper(CUSTOM_FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(CUSTOM_FILE_PATH))
+
+    # just one string label
+    wrong_data = b"a;b;c;d;88;12\ne;f;g;h;h;76"
+    write_to_file(wrong_data)
+    with pytest.raises(ValueError):
+        _ = CsvFileWrapper(CUSTOM_FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(CUSTOM_FILE_PATH))
+
+
+def test_invalid_file_content_skip_validation():
+    # extra field in one row
+    wrong_data = b"a;b;c;d;12;e\ne;f;g;h;76"
+    write_to_file(wrong_data)
+
+    config = FILE_WRAPPER_CONFIG.copy()
+    config["validate_file_content"] = False
+
+    _ = CsvFileWrapper(CUSTOM_FILE_PATH, config, MockFileSystemWrapper(CUSTOM_FILE_PATH))  # must not raise
+
+    # label column outside row boundary
+    wrong_data = b"a;b;c;12\ne;f;g;76"
+    write_to_file(wrong_data)
+
+    file_wrapper = CsvFileWrapper(CUSTOM_FILE_PATH, config, MockFileSystemWrapper(CUSTOM_FILE_PATH))
+
+    with pytest.raises(IndexError):  # fails since label_index is outside the row
+        file_wrapper.get_label(1)
+
+    # string label column
+    wrong_data = b"a;b;c;d;e;12\ne;f;g;h;h;76"
+    write_to_file(wrong_data)
+    file_wrapper = CsvFileWrapper(CUSTOM_FILE_PATH, config, MockFileSystemWrapper(CUSTOM_FILE_PATH))
+
+    with pytest.raises(ValueError):  # fails to convert to integer
+        file_wrapper.get_label(1)
+
+    # just one string label
+    wrong_data = b"a;b;c;d;88;12\ne;f;g;h;h;76"
+    write_to_file(wrong_data)
+    file_wrapper = CsvFileWrapper(CUSTOM_FILE_PATH, config, MockFileSystemWrapper(CUSTOM_FILE_PATH))
+
+    file_wrapper.get_label(0)  # does not fail since row 0 is ok
+    with pytest.raises(ValueError):  # fails to convert to integer
+        file_wrapper.get_label(1)
+
+
+def test_different_separator():
+    tsv_file_data = b"a\tb\tc\td\t12\ne\tf\tg\th\t76"
+
+    tsv_file_wrapper_config = {
+        "ignore_first_line": False,
+        "label_index": 4,
+        "separator": "\t",
+    }
+
+    write_to_file(tsv_file_data)
+    tsv_file_wrapper = CsvFileWrapper(
+        CUSTOM_FILE_PATH, tsv_file_wrapper_config, MockFileSystemWrapper(CUSTOM_FILE_PATH)
+    )
+    csv_file_wrapper = CsvFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH))
+
+    assert tsv_file_wrapper.get_number_of_samples() == csv_file_wrapper.get_number_of_samples()
+
+    assert tsv_file_wrapper.get_sample(0) == b"a\tb\tc\td"
+    assert tsv_file_wrapper.get_sample(1) == b"e\tf\tg\th"
+
+    tsv_samples = tsv_file_wrapper.get_samples(0, 2)
+    csv_samples = csv_file_wrapper.get_samples(0, 2)
+
+    tsv_samples = [sample.decode("utf-8").split("\t") for sample in tsv_samples]
+    csv_samples = [sample.decode("utf-8").split(";") for sample in csv_samples]
+    assert tsv_samples == csv_samples
+
+    assert tsv_file_wrapper.get_label(0) == csv_file_wrapper.get_label(0)
+    assert tsv_file_wrapper.get_label(1) == csv_file_wrapper.get_label(1)
+
+
+def test_out_of_order_sequence():
+    content = b"A1;B1;C1;1\nA2;B2;C2;2\nA3;B3;C3;3\nA4;B4;C4;4\nA5;B5;C5;5"
+    converted = [b"A1;B1;C1", b"A2;B2;C2", b"A3;B3;C3", b"A4;B4;C4", b"A5;B5;C5"]
+    write_to_file(content)
+    config = {
+        "ignore_first_line": False,
+        "label_index": 3,
+        "separator": ";",
+    }
+    file_wrapper = CsvFileWrapper(CUSTOM_FILE_PATH, config, MockFileSystemWrapper(CUSTOM_FILE_PATH))
+
+    # samples are returned in the requested order
+    assert file_wrapper.get_samples_from_indices([2, 1]) == [converted[2], converted[1]]
+    assert file_wrapper.get_samples_from_indices([3, 2, 1]) == [converted[3], converted[2], converted[1]]
+    assert file_wrapper.get_samples_from_indices([3, 2, 4, 1]) == [
+        converted[3],
+        converted[2],
+        converted[4],
+        converted[1],
+    ]
+
+
+def test_duplicate_request():
+    content = b"A1;B1;C1;1\nA2;B2;C2;2\nA3;B3;C3;3\nA4;B4;C4;4\nA5;B5;C5;5"
+    write_to_file(content)
+    config = {
+        "ignore_first_line": False,
+        "label_index": 3,
+        "separator": ";",
+    }
+    file_wrapper = CsvFileWrapper(CUSTOM_FILE_PATH, config, MockFileSystemWrapper(CUSTOM_FILE_PATH))
+
+    with pytest.raises(AssertionError):
+        file_wrapper.get_samples_from_indices([1, 1])
+
+    with pytest.raises(AssertionError):
+        file_wrapper.get_samples_from_indices([1, 1, 3])
+
+    with pytest.raises(AssertionError):
+        file_wrapper.get_samples_from_indices([1, 1, 13])
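Finally, a hedged sketch of how a CSV dataset with a header row could be registered against the storage service, mirroring `register_new_dataset()` from the integration test above; the dataset id and base path are placeholders:

```python
# Illustrative only: build a registration request for a semicolon-separated,
# headered CSV dataset. The request would be sent via
# StorageStub(channel).RegisterNewDataset(request), as in the integration test.
import json

from modyn.storage.internal.grpc.generated.storage_pb2 import RegisterNewDatasetRequest


def build_csv_register_request() -> RegisterNewDatasetRequest:
    return RegisterNewDatasetRequest(
        base_path="/datasets/my_csv_dataset",  # placeholder path
        dataset_id="my_csv_dataset",           # placeholder id
        description="CSV dataset with a header row.",
        file_wrapper_config=json.dumps(
            {
                "file_extension": ".csv",
                "separator": ";",
                "label_index": 4,
                "ignore_first_line": True,   # skip the header
                "encoding": "utf-8",
                "validate_file_content": True,
            }
        ),
        file_wrapper_type="CsvFileWrapper",
        filesystem_wrapper_type="LocalFilesystemWrapper",
        version="0.1.0",
    )
```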