Skip to content

Commit

Permalink
Specify directory for write_transaction
Browse files Browse the repository at this point in the history
  • Loading branch information
eivindjahren committed Sep 26, 2024
1 parent 9c23db2 commit 31af4fe
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 74 deletions.
16 changes: 0 additions & 16 deletions src/ert/storage/_write_transaction.py

This file was deleted.

9 changes: 6 additions & 3 deletions src/ert/storage/local_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from ert.config.gen_kw_config import GenKwConfig
from ert.storage.mode import BaseMode, Mode, require_write

from ._write_transaction import write_transaction
from .realization_storage_state import RealizationStorageState

if TYPE_CHECKING:
Expand Down Expand Up @@ -144,7 +143,9 @@ def create(
started_at=datetime.now(),
)

write_transaction(path / "index.json", index.model_dump_json().encode("utf-8"))
storage._write_transaction(
path / "index.json", index.model_dump_json().encode("utf-8")
)

return cls(storage, path, Mode.WRITE)

Expand Down Expand Up @@ -422,7 +423,9 @@ def set_failure(
error = _Failure(
type=failure_type, message=message if message else "", time=datetime.now()
)
write_transaction(filename, error.model_dump_json().encode("utf-8"))
self._storage._write_transaction(
filename, error.model_dump_json().encode("utf-8")
)

def unset_failure(
self,
Expand Down
8 changes: 3 additions & 5 deletions src/ert/storage/local_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
from ert.config.response_config import ResponseConfig
from ert.storage.mode import BaseMode, Mode, require_write

from ._write_transaction import write_transaction

if TYPE_CHECKING:
from ert.config.parameter_config import ParameterConfig
from ert.storage.local_ensemble import LocalEnsemble
Expand Down Expand Up @@ -131,15 +129,15 @@ def create(
for parameter in parameters or []:
parameter.save_experiment_data(path)
parameter_data.update({parameter.name: parameter.to_dict()})
write_transaction(
storage._write_transaction(
path / cls._parameter_file,
json.dumps(parameter_data, indent=2).encode("utf-8"),
)

response_data = {}
for response in responses or []:
response_data.update({response.response_type: response.to_dict()})
write_transaction(
storage._write_transaction(
path / cls._responses_file,
json.dumps(response_data, default=str, indent=2).encode("utf-8"),
)
Expand All @@ -151,7 +149,7 @@ def create(
dataset.to_netcdf(output_path / f"{obs_name}", engine="scipy")

simulation_data = simulation_arguments if simulation_arguments else {}
write_transaction(
storage._write_transaction(
path / cls._metadata_file,
json.dumps(simulation_data, cls=ContextBoolEncoder).encode("utf-8"),
)
Expand Down
23 changes: 20 additions & 3 deletions src/ert/storage/local_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import os
import shutil
from datetime import datetime
from functools import cached_property
from pathlib import Path
from tempfile import NamedTemporaryFile
from textwrap import dedent
from types import TracebackType
from typing import (
Expand Down Expand Up @@ -38,8 +40,6 @@
)
from ert.storage.realization_storage_state import RealizationStorageState

from ._write_transaction import write_transaction

if TYPE_CHECKING:
from ert.config import ParameterConfig, ResponseConfig

Expand Down Expand Up @@ -73,6 +73,7 @@ class LocalStorage(BaseMode):
LOCK_TIMEOUT = 5
EXPERIMENTS_PATH = "experiments"
ENSEMBLES_PATH = "ensembles"
SWAP_PATH = "swp"

def __init__(
self,
Expand Down Expand Up @@ -250,6 +251,10 @@ def _ensemble_path(self, ensemble_id: UUID) -> Path:
def _experiment_path(self, experiment_id: UUID) -> Path:
return self.path / self.EXPERIMENTS_PATH / str(experiment_id)

@cached_property
def _swap_path(self) -> Path:
return self.path / self.SWAP_PATH

def __enter__(self) -> LocalStorage:
return self

Expand Down Expand Up @@ -448,7 +453,7 @@ def _add_migration_information(

@require_write
def _save_index(self) -> None:
write_transaction(
self._write_transaction(
self.path / "index.json",
self._index.model_dump_json(indent=4).encode("utf-8"),
)
Expand Down Expand Up @@ -550,6 +555,18 @@ def get_unique_experiment_name(self, experiment_name: str) -> str:
else:
return experiment_name + "_0"

def _write_transaction(self, filename: str | os.PathLike[str], data: bytes) -> None:
"""
Writes the data to the filename as a transaction.
Guarantees to not leave half-written or empty files on disk if the write
fails or the process is killed.
"""
self._swap_path.mkdir(parents=True, exist_ok=True)
with NamedTemporaryFile(dir=self._swap_path, delete=False) as f:
f.write(data)
os.rename(f.name, filename)


def _storage_version(path: Path) -> int:
if not path.exists():
Expand Down
20 changes: 19 additions & 1 deletion tests/ert/unit_tests/storage/test_local_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import numpy as np
import pytest
import xarray as xr
from hypothesis import assume
from hypothesis import assume, given
from hypothesis.extra.numpy import arrays
from hypothesis.stateful import Bundle, RuleBasedStateMachine, initialize, rule

Expand Down Expand Up @@ -483,6 +483,24 @@ def fields(draw, egrid, num_fields=small_ints) -> List[Field]:
]


@pytest.mark.usefixtures("use_tmpdir")
@given(st.binary())
def test_write_transaction(data):
with open_storage(".", "w") as storage:
filepath = Path("./file.txt")
storage._write_transaction(filepath, data)

assert filepath.read_bytes() == data


def test_write_transaction_overwrites(tmp_path):
with open_storage(tmp_path, "w") as storage:
path = tmp_path / "file.txt"
path.write_text("abc")
storage._write_transaction(path, b"deadbeaf")
assert path.read_bytes() == b"deadbeaf"


@dataclass
class Ensemble:
uuid: UUID
Expand Down
46 changes: 0 additions & 46 deletions tests/ert/unit_tests/storage/test_write_transaction.py

This file was deleted.

0 comments on commit 31af4fe

Please sign in to comment.