diff --git a/mltb2/files.py b/mltb2/files.py
index 03da345..b043257 100644
--- a/mltb2/files.py
+++ b/mltb2/files.py
@@ -192,7 +192,7 @@ def save_batch(self, batch: Sequence[Dict[str, Any]]) -> None:
         self._remove_lock_files(batch)
 
     @staticmethod
-    def load_data(result_dir: str) -> List[Dict[str, Any]]:
+    def load_data(result_dir: str, ignore_load_error: bool = False) -> List[Dict[str, Any]]:
         """Load all data.
 
         After all data is processed, this method can be used to load all data.
@@ -201,6 +201,7 @@ def load_data(result_dir: str) -> List[Dict[str, Any]]:
 
         Args:
             result_dir: The directory where the results are stored.
+            ignore_load_error: Ignore errors when loading the result files. Just print them.
         """
         _result_dir_path = Path(result_dir)
         if not _result_dir_path.is_dir():
@@ -212,6 +213,17 @@ def load_data(result_dir: str) -> List[Dict[str, Any]]:
             if child_path.is_file() and child_path.name.endswith(".pkl.gz"):
                 uuid = FileBasedRestartableBatchDataProcessor._get_uuid_from_filename(child_path.name)
                 if uuid not in uuids:
-                    uuids.add(uuid)
-                    data.append(joblib.load(child_path))
+
+                    d = None
+                    try:
+                        d = joblib.load(child_path)
+                    except Exception as e:
+                        if ignore_load_error:
+                            print(f"Error loading file '{child_path}': {e}")
+                        else:
+                            raise e  # NOQA: TRY201
+
+                    if d is not None:
+                        uuids.add(uuid)
+                        data.append(d)
         return data
diff --git a/pyproject.toml b/pyproject.toml
index bd6d9f4..4b46417 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "mltb2"
-version = "1.0.0rc2"
+version = "1.0.0rc3"
 description = "Machine Learning Toolbox 2"
 authors = ["PhilipMay "]
 readme = "README.md"
diff --git a/tests/test_files.py b/tests/test_files.py
index 14297aa..5b49089 100644
--- a/tests/test_files.py
+++ b/tests/test_files.py
@@ -254,3 +254,51 @@ def test_FileBasedRestartableBatchDataProcessor_clear_lock_files(tmp_path):
         assert isinstance(d["uuid"], str)
         assert isinstance(d["x"], int)
         assert d["x"] < 100
+
+
+def test_FileBasedRestartableBatchDataProcessor_load_data_with_error(tmp_path):
+    result_dir = tmp_path.absolute()
+    batch_size = 10
+    data = [{"uuid": str(uuid4()), "x": i} for i in range(100)]
+    data_processor = FileBasedRestartableBatchDataProcessor(
+        data=data, batch_size=batch_size, uuid_name="uuid", result_dir=result_dir
+    )
+
+    # process all data
+    while True:
+        _data = data_processor.read_batch()
+        if len(_data) == 0:
+            break
+        data_processor.save_batch(_data)
+
+    first_data_file = list(tmp_path.glob("*.pkl.gz"))[0]
+    with open(first_data_file, "w") as f:
+        f.write("")
+
+    del data_processor
+    with pytest.raises(EOFError):
+        _ = FileBasedRestartableBatchDataProcessor.load_data(result_dir)
+
+
+def test_FileBasedRestartableBatchDataProcessor_load_data_ignore_error(tmp_path):
+    result_dir = tmp_path.absolute()
+    batch_size = 10
+    data = [{"uuid": str(uuid4()), "x": i} for i in range(100)]
+    data_processor = FileBasedRestartableBatchDataProcessor(
+        data=data, batch_size=batch_size, uuid_name="uuid", result_dir=result_dir
+    )
+
+    # process all data
+    while True:
+        _data = data_processor.read_batch()
+        if len(_data) == 0:
+            break
+        data_processor.save_batch(_data)
+
+    first_data_file = list(tmp_path.glob("*.pkl.gz"))[0]
+    with open(first_data_file, "w") as f:
+        f.write("")
+
+    del data_processor
+    processed_data = FileBasedRestartableBatchDataProcessor.load_data(result_dir, ignore_load_error=True)
+    assert len(processed_data) == len(data) - 1
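
For context, a minimal usage sketch of the new `ignore_load_error` flag, based only on the API visible in this diff and its tests; the `"./results"` directory and the batch transform are illustrative placeholders, not part of the change:

```python
from uuid import uuid4

from mltb2.files import FileBasedRestartableBatchDataProcessor

# 100 records, each keyed by a UUID (mirrors the test setup above)
data = [{"uuid": str(uuid4()), "x": i} for i in range(100)]
processor = FileBasedRestartableBatchDataProcessor(
    data=data, batch_size=10, uuid_name="uuid", result_dir="./results"  # illustrative path
)

# process everything in batches; read_batch() returns an empty
# sequence once all data has been claimed, and the loop can be
# restarted after a crash without redoing finished batches
while True:
    batch = processor.read_batch()
    if len(batch) == 0:
        break
    # ... transform the batch here ...
    processor.save_batch(batch)

# collect all results; with ignore_load_error=True a corrupt .pkl.gz
# file is printed and skipped instead of raising (e.g. an EOFError)
results = FileBasedRestartableBatchDataProcessor.load_data("./results", ignore_load_error=True)
```

Without the flag, `load_data` keeps its previous fail-fast behavior and re-raises the underlying exception, as the first new test asserts.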