Add ignore error in FileBasedRestartableBatchDataProcessor data loading. (#163)

* add ignore_load_error

* bump rc

* update doc
PhilipMay authored May 30, 2024
1 parent 57f3a0e commit d8dc066
Showing 3 changed files with 64 additions and 4 deletions.
18 changes: 15 additions & 3 deletions mltb2/files.py
@@ -192,7 +192,7 @@ def save_batch(self, batch: Sequence[Dict[str, Any]]) -> None:
         self._remove_lock_files(batch)
 
     @staticmethod
-    def load_data(result_dir: str) -> List[Dict[str, Any]]:
+    def load_data(result_dir: str, ignore_load_error: bool = False) -> List[Dict[str, Any]]:
         """Load all data.
 
         After all data is processed, this method can be used to load all data.
@@ -201,6 +201,7 @@ def load_data(result_dir: str) -> List[Dict[str, Any]]:
 
         Args:
             result_dir: The directory where the results are stored.
+            ignore_load_error: Ignore errors when loading the result files. Just print them.
         """
         _result_dir_path = Path(result_dir)
         if not _result_dir_path.is_dir():
@@ -212,6 +213,17 @@ def load_data(result_dir: str) -> List[Dict[str, Any]]:
             if child_path.is_file() and child_path.name.endswith(".pkl.gz"):
                 uuid = FileBasedRestartableBatchDataProcessor._get_uuid_from_filename(child_path.name)
                 if uuid not in uuids:
-                    uuids.add(uuid)
-                    data.append(joblib.load(child_path))
+                    d = None
+                    try:
+                        d = joblib.load(child_path)
+                    except Exception as e:
+                        if ignore_load_error:
+                            print(f"Error loading file '{child_path}': {e}")
+                        else:
+                            raise e  # NOQA: TRY201
+
+                    if d is not None:
+                        uuids.add(uuid)
+                        data.append(d)
         return data
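
A minimal usage sketch of the changed API, assuming mltb2 is installed; the ./results path is illustrative and not part of the commit:

    from mltb2.files import FileBasedRestartableBatchDataProcessor

    # Default (strict) behavior: a corrupt result file raises the underlying
    # exception, e.g. EOFError for a truncated .pkl.gz file.
    data = FileBasedRestartableBatchDataProcessor.load_data("./results")

    # With the new flag: load errors are printed and the affected files are
    # skipped, so only the readable records are returned.
    data = FileBasedRestartableBatchDataProcessor.load_data("./results", ignore_load_error=True)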
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "mltb2"
-version = "1.0.0rc2"
+version = "1.0.0rc3"
 description = "Machine Learning Toolbox 2"
 authors = ["PhilipMay <[email protected]>"]
 readme = "README.md"
48 changes: 48 additions & 0 deletions tests/test_files.py
@@ -254,3 +254,51 @@ def test_FileBasedRestartableBatchDataProcessor_clear_lock_files(tmp_path):
         assert isinstance(d["uuid"], str)
         assert isinstance(d["x"], int)
         assert d["x"] < 100
+
+
+def test_FileBasedRestartableBatchDataProcessor_load_data_with_error(tmp_path):
+    result_dir = tmp_path.absolute()
+    batch_size = 10
+    data = [{"uuid": str(uuid4()), "x": i} for i in range(100)]
+    data_processor = FileBasedRestartableBatchDataProcessor(
+        data=data, batch_size=batch_size, uuid_name="uuid", result_dir=result_dir
+    )
+
+    # process all data
+    while True:
+        _data = data_processor.read_batch()
+        if len(_data) == 0:
+            break
+        data_processor.save_batch(_data)
+
+    first_data_file = list(tmp_path.glob("*.pkl.gz"))[0]
+    with open(first_data_file, "w") as f:
+        f.write("")
+
+    del data_processor
+    with pytest.raises(EOFError):
+        _ = FileBasedRestartableBatchDataProcessor.load_data(result_dir)
+
+
+def test_FileBasedRestartableBatchDataProcessor_load_data_ignore_error(tmp_path):
+    result_dir = tmp_path.absolute()
+    batch_size = 10
+    data = [{"uuid": str(uuid4()), "x": i} for i in range(100)]
+    data_processor = FileBasedRestartableBatchDataProcessor(
+        data=data, batch_size=batch_size, uuid_name="uuid", result_dir=result_dir
+    )
+
+    # process all data
+    while True:
+        _data = data_processor.read_batch()
+        if len(_data) == 0:
+            break
+        data_processor.save_batch(_data)
+
+    first_data_file = list(tmp_path.glob("*.pkl.gz"))[0]
+    with open(first_data_file, "w") as f:
+        f.write("")
+
+    del data_processor
+    processed_data = FileBasedRestartableBatchDataProcessor.load_data(result_dir, ignore_load_error=True)
+    assert len(processed_data) == len(data) - 1
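
Note: both tests truncate one result file to zero bytes, so joblib.load fails while reading the gzip stream and raises EOFError. The strict default lets that exception propagate, while ignore_load_error=True prints the error and skips the file, which is why exactly one of the 100 records is missing.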
