Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Use a transactional write in storage
Browse files Browse the repository at this point in the history
eivindjahren committed Sep 26, 2024
1 parent 21af329 commit 9c23db2
Showing 5 changed files with 86 additions and 13 deletions.
16 changes: 16 additions & 0 deletions src/ert/storage/_write_transaction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from __future__ import annotations

import os


def write_transaction(filename: str | os.PathLike[str], data: bytes) -> None:
    """Write ``data`` to ``filename`` as a transaction.

    The data is first written to a swap file next to the target
    (``<filename>.swp``) and the swap file is then moved over ``filename``
    in a single step. The target therefore holds either its previous content
    or the complete new content, never a half-written or empty file.

    If writing or moving fails, the swap file is removed again (best effort)
    and the original exception is re-raised; ``filename`` is left untouched.
    A process that is killed outright cannot clean up, so a stale ``.swp``
    file may remain in that case, but ``filename`` itself is still intact.

    The data is not fsync'ed, so this protects against failed writes and
    killed processes, not against power loss or an operating system crash.

    Args:
        filename: Path of the file to create or overwrite.
        data: The complete content to store in the file.

    Raises:
        OSError: If the swap file cannot be written or moved into place.
    """
    swapfile = os.fspath(filename) + ".swp"
    try:
        with open(swapfile, mode="wb") as f:
            f.write(data)

        # os.replace overwrites an existing destination on every platform,
        # whereas os.rename raises on Windows when the destination exists.
        os.replace(swapfile, filename)
    except BaseException:
        # Do not leave a half-written swap file behind. The cleanup is best
        # effort: the swap file may never have been created, and a failing
        # cleanup must not mask the error that brought us here.
        try:
            os.remove(swapfile)
        except OSError:
            pass
        raise
7 changes: 3 additions & 4 deletions src/ert/storage/local_ensemble.py
Original file line number Diff line number Diff line change
@@ -18,6 +18,7 @@
from ert.config.gen_kw_config import GenKwConfig
from ert.storage.mode import BaseMode, Mode, require_write

from ._write_transaction import write_transaction
from .realization_storage_state import RealizationStorageState

if TYPE_CHECKING:
@@ -143,8 +144,7 @@ def create(
started_at=datetime.now(),
)

with open(path / "index.json", mode="w", encoding="utf-8") as f:
print(index.model_dump_json(), file=f)
write_transaction(path / "index.json", index.model_dump_json().encode("utf-8"))

return cls(storage, path, Mode.WRITE)

@@ -422,8 +422,7 @@ def set_failure(
error = _Failure(
type=failure_type, message=message if message else "", time=datetime.now()
)
with open(filename, mode="w", encoding="utf-8") as f:
print(error.model_dump_json(), file=f)
write_transaction(filename, error.model_dump_json().encode("utf-8"))

def unset_failure(
self,
22 changes: 15 additions & 7 deletions src/ert/storage/local_experiment.py
Original file line number Diff line number Diff line change
@@ -22,6 +22,8 @@
from ert.config.response_config import ResponseConfig
from ert.storage.mode import BaseMode, Mode, require_write

from ._write_transaction import write_transaction

if TYPE_CHECKING:
from ert.config.parameter_config import ParameterConfig
from ert.storage.local_ensemble import LocalEnsemble
@@ -129,24 +131,30 @@ def create(
for parameter in parameters or []:
parameter.save_experiment_data(path)
parameter_data.update({parameter.name: parameter.to_dict()})
with open(path / cls._parameter_file, "w", encoding="utf-8") as f:
json.dump(parameter_data, f, indent=2)
write_transaction(
path / cls._parameter_file,
json.dumps(parameter_data, indent=2).encode("utf-8"),
)

response_data = {}
for response in responses or []:
response_data.update({response.response_type: response.to_dict()})
with open(path / cls._responses_file, "w", encoding="utf-8") as f:
json.dump(response_data, f, default=str, indent=2)
write_transaction(
path / cls._responses_file,
json.dumps(response_data, default=str, indent=2).encode("utf-8"),
)

if observations:
output_path = path / "observations"
output_path.mkdir()
for obs_name, dataset in observations.items():
dataset.to_netcdf(output_path / f"{obs_name}", engine="scipy")

with open(path / cls._metadata_file, "w", encoding="utf-8") as f:
simulation_data = simulation_arguments if simulation_arguments else {}
json.dump(simulation_data, f, cls=ContextBoolEncoder)
simulation_data = simulation_arguments if simulation_arguments else {}
write_transaction(
path / cls._metadata_file,
json.dumps(simulation_data, cls=ContextBoolEncoder).encode("utf-8"),
)

return cls(storage, path, Mode.WRITE)

8 changes: 6 additions & 2 deletions src/ert/storage/local_storage.py
Original file line number Diff line number Diff line change
@@ -38,6 +38,8 @@
)
from ert.storage.realization_storage_state import RealizationStorageState

from ._write_transaction import write_transaction

if TYPE_CHECKING:
from ert.config import ParameterConfig, ResponseConfig

@@ -446,8 +448,10 @@ def _add_migration_information(

@require_write
def _save_index(self) -> None:
with open(self.path / "index.json", mode="w", encoding="utf-8") as f:
print(self._index.model_dump_json(indent=4), file=f)
write_transaction(
self.path / "index.json",
self._index.model_dump_json(indent=4).encode("utf-8"),
)

@require_write
def _migrate(self, version: int) -> None:
46 changes: 46 additions & 0 deletions tests/ert/unit_tests/storage/test_write_transaction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from pathlib import Path
from unittest.mock import mock_open, patch

import hypothesis.strategies as st
import pytest
from hypothesis import given

from ert.storage._write_transaction import write_transaction


@pytest.mark.usefixtures("use_tmpdir")
@given(st.binary())
def test_write_transaction(data):
    """Any byte string passed to write_transaction can be read back unchanged."""
    target = Path("./file.txt")
    write_transaction(target, data)

    written = target.read_bytes()
    assert written == data


def test_write_transaction_failure(tmp_path):
    """A write that raises must never result in the target file being opened.

    ``builtins.open`` is replaced by a mock whose file handle raises on
    ``write``; afterwards the recorded calls are inspected to make sure none
    of them referred to the target path itself.
    """

    def _raise_on_write(_):
        raise RuntimeError()

    target = tmp_path / "file.txt"

    with patch("builtins.open", mock_open()) as opener:
        file_handle = opener()
        file_handle.write = _raise_on_write

        with pytest.raises(RuntimeError):
            write_transaction(target, b"deadbeaf")

    # Collect every recorded call that mentions the target, either
    # positionally or through the ``file`` keyword argument.
    calls_touching_target = [
        call
        for call in opener.mock_calls
        if target in call.args
        or str(target) in call.args
        or call.kwargs.get("file") in [target, str(target)]
    ]
    assert (
        not calls_touching_target
    ), "There should be no calls opening the file when an write encounters a RuntimeError"


def test_write_transaction_overwrites(tmp_path):
    """write_transaction replaces the content of an already existing file."""
    target = tmp_path / "file.txt"
    target.write_text("abc")

    new_content = b"deadbeaf"
    write_transaction(target, new_content)

    assert target.read_bytes() == new_content

0 comments on commit 9c23db2

Please sign in to comment.