diff --git a/src/ovds_utils/ovds/enums.py b/src/ovds_utils/ovds/enums.py index cec155d..7f8f8ff 100644 --- a/src/ovds_utils/ovds/enums.py +++ b/src/ovds_utils/ovds/enums.py @@ -6,6 +6,7 @@ class InitValue(Enum): NaN = auto() zero = auto() + omit_init = auto() class Formats(Enum): diff --git a/src/ovds_utils/ovds/writing.py b/src/ovds_utils/ovds/writing.py index 9f86d5a..6425158 100644 --- a/src/ovds_utils/ovds/writing.py +++ b/src/ovds_utils/ovds/writing.py @@ -1,3 +1,4 @@ +from logging import getLogger from typing import Any, AnyStr, Dict, List import numpy as np @@ -6,6 +7,9 @@ from .enums import InitValue from .utils import copy_ovds_metadata + +logger = getLogger(__name__) + FORMAT2NPTYPE = { openvds.VolumeDataChannelDescriptor.Format.Format_R64: np.float64, openvds.VolumeDataChannelDescriptor.Format.Format_R32: np.float32, @@ -133,7 +137,7 @@ def create_vds( default_max_pages: int = 8, channels_data=None, init_value: InitValue = InitValue.zero, - close=True + # close: bool = True ): ( layout_descriptor, @@ -200,8 +204,12 @@ def create_vds( channel=i, maxPages=default_max_pages, ) - INITVALUE[init_value](accessor, channel.format.value) + if init_value != InitValue.omit_init: + INITVALUE[init_value](accessor, channel.format.value) + else: + logger.warning( + "Values initialization was omitted. Remember to do it and close vds resource properly." + ) + return vds - if close: - openvds.close(vds) return vds diff --git a/src/ovds_utils/vds.py b/src/ovds_utils/vds.py index 37378d4..b4b4095 100644 --- a/src/ovds_utils/vds.py +++ b/src/ovds_utils/vds.py @@ -1,5 +1,6 @@ from copy import deepcopy from typing import AnyStr, List, Sequence, Tuple, Union +from enum import Enum import numpy as np import openvds @@ -227,6 +228,12 @@ def commit(self): self.accessor.commit() +class Mode(str, Enum): + READ = "READ" + READ_WRITE = "READ_WRITE" + WRITE = "WRITE" + + class VDS: def __init__( self, @@ -243,46 +250,59 @@ def __init__( options: Options = Options._None, full_resolution_dimension: int = 0, brick_size_2d_multiplier: int = 4, - init_value: InitValue = InitValue.zero + init_value: InitValue = InitValue.zero, + mode: Mode = Mode.READ ) -> None: super().__init__() - self.path = path - self.connection_string = connection_string - self._channels = {} - self._axes = {} - try: - self._vds_source = openvds.open(path, connection_string) - vds_info = get_vds_info(path, connection_string) - except RuntimeError as e: - if str(e) in ("Open error: File::open "): - logger.debug("Creating new VDS source...") - self.create( - path=path, - connection_string=connection_string, - databrick_size=databrick_size, - channels=channels, - axes=axes, - metadata_dict=metadata_dict, - channels_data=channels_data, - access_mode=AccessModes.Create, - lod=lod, - positive_margin=positive_margin, - negagitve_margin=negative_margin, - options=options, - full_resolution_dimension=full_resolution_dimension, - brick_size_2d_multiplier=brick_size_2d_multiplier, - init_value=init_value - ) - self = self.__init__( + if mode in (Mode.READ_WRITE, Mode.READ): + try: + self._vds_source = openvds.open(path, connection_string) + self.initialize( path=path, connection_string=connection_string, databrick_size=databrick_size, - metadata_dict=metadata_dict, ) - return - else: - raise VDSException(f"Open VDS resulted with: {str(e)}") from e + except RuntimeError as e: + if str(e) == "Open error: File::open ": + raise VDSException(f"Could not open vds for path {path}") + else: + logger.debug("Creating new VDS source...") + self._vds_source = self.create( + path=path, + connection_string=connection_string, + databrick_size=databrick_size, + channels=channels, + axes=axes, + metadata_dict=metadata_dict, + channels_data=channels_data, + access_mode=AccessModes.Create, + lod=lod, + positive_margin=positive_margin, + negagitve_margin=negative_margin, + options=options, + full_resolution_dimension=full_resolution_dimension, + brick_size_2d_multiplier=brick_size_2d_multiplier, + init_value=init_value + ) + self.initialize( + path=path, + connection_string=connection_string, + databrick_size=databrick_size, + ) + + def initialize( + self, + path: AnyStr, + connection_string: AnyStr = "", + databrick_size: BrickSizes = BrickSizes._128, + ): + self.closed = False + self.path = path + self.connection_string = connection_string + self._channels = {} + self._axes = {} + self.closed = False self._layout = openvds.getLayout(self._vds_source) self._dimensionality = self._layout.getDimensionality() @@ -329,14 +349,19 @@ def axes(self) -> List[Axis]: def get_channel(self, name: AnyStr) -> Channel: return self._channels[name] + def close(self, flush: bool = True): + if not getattr(self, "closed", True): + openvds.close(self._vds_source, flush) + self.closed = True + def __del__(self): - openvds.close(self._vds_source) + self.close() def __enter__(self): return self def __exit__(self, *args, **kwargs): - return + self.close() @property def metadata(self) -> MetadataContainer: @@ -422,7 +447,7 @@ def create( brick_size_2d_multiplier=brick_size_2d_multiplier, full_resolution_dimension=full_resolution_dimension, init_value=init_value, - close=True + # close=True ) @property diff --git a/tests/test_vds.py b/tests/test_vds.py index 8f2f23f..679d4ae 100644 --- a/tests/test_vds.py +++ b/tests/test_vds.py @@ -5,7 +5,7 @@ from ovds_utils.metadata import MetadataTypes, MetadataValue from ovds_utils.ovds.enums import BrickSizes, Formats -from ovds_utils.vds import VDS, Axis, Channel, Components +from ovds_utils.vds import VDS, Axis, Channel, Components, Mode, InitValue def test_vds_shape(): @@ -40,7 +40,8 @@ def test_vds_shape(): channels_data=[ data ], - databrick_size=BrickSizes._128 + databrick_size=BrickSizes._128, + mode=Mode.WRITE ) assert vds[:, :, :].shape == vds.shape == shape @@ -85,7 +86,8 @@ def test_create_vds_by_chunks(): value_range_max=1000.0, components=Components._1 ) - ] + ], + mode=Mode.WRITE ) for chunk in vds.channel(0).chunks(): chunk[:, :, :] = data[chunk.slices] @@ -133,8 +135,70 @@ def test_vds_3d_cube_default_axis_name(): components=Components._1 ) ], - databrick_size=BrickSizes._128 + databrick_size=BrickSizes._128, + mode=Mode.WRITE ) vds.axis_descriptors[0] == ("Sample", "unitless", 126) vds.axis_descriptors[1] == ("Crossline", "unitless", 51) vds.axis_descriptors[2] == ("Inline", "unitless", 251) + + +def test_vds_3d_cube_initialize_and_wirte_later(): + shape = (251, 51, 126) + dtype = np.float64 + data = np.random.rand(*shape).astype(dtype) + metadata = { + "example": MetadataValue(value="value", category="category#1", type=MetadataTypes.String) + } + data = np.random.rand(*shape).astype(dtype) + names = ["Sample", "Crossline", "Inline"] + axes = [ + Axis( + samples=s, + name=names[i], + unit="unitless", + coordinate_max=1000.0, + coordinate_min=-1000.0 + ) + for i, s in enumerate(shape) + ] + with TemporaryDirectory() as dir: + with VDS( + os.path.join(dir, "example.vds"), + metadata_dict=metadata, + axes=axes, + init_value=InitValue.omit_init, + databrick_size=BrickSizes._64, + channels=[ + Channel( + name="Channel0", + format=Formats.R64, + unit="unitless", + value_range_min=0.0, + value_range_max=1000.0, + components=Components._1 + ) + ], + mode=Mode.WRITE + ) as vds: + pass + + with VDS( + os.path.join(dir, "example.vds"), + mode=Mode.READ_WRITE + ) as vds: + for chunk in vds.channel(0).chunks(): + chunk[:, :, :] = data[chunk.slices] + chunk.release() + vds.channel(0).commit() + + with VDS( + os.path.join(dir, "example.vds"), + mode=Mode.READ + ) as vds: + assert vds[:, :, :].shape == vds.shape == shape + for _ in range(shape[0]): + assert all( + np.array_equal(data[i, 0, :], vds[i, 0, :]) + for i in range(shape[0]) + )