Skip to content

Commit

Permalink
Handle right
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby committed Jan 22, 2025
1 parent 7dfdfdf commit f5e0724
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 19 deletions.
26 changes: 26 additions & 0 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,7 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
"_metadata_props",
"_offset",
"_shape",
"_valid",
"doc_string",
"name",
"raw",
Expand Down Expand Up @@ -568,6 +569,7 @@ def __init__(
self.raw: mmap.mmap | None = None
self._metadata_props = metadata_props
self._metadata: _metadata.MetadataStore | None = None
self._valid = True

@property
def base_dir(self) -> str | os.PathLike:
Expand Down Expand Up @@ -609,6 +611,7 @@ def shape(self) -> Shape:
return self._shape

def _load(self):
self._check_validity()
assert self._array is None, "Bug: The array should be loaded only once."
if self.size == 0:
# When the size is 0, mmap is impossible and meaningless
Expand Down Expand Up @@ -647,6 +650,7 @@ def _load(self):
self._array = self._array.reshape(shape)

def __array__(self, dtype: Any = None) -> np.ndarray:
self._check_validity()
if self._array is None:
self._load()
assert self._array is not None
Expand Down Expand Up @@ -675,6 +679,7 @@ def numpy(self) -> np.ndarray:
The data will be memory mapped and will not take up physical memory space.
"""
self._check_validity()
if self._array is None:
self._load()
assert self._array is not None
Expand All @@ -685,13 +690,34 @@ def tobytes(self) -> bytes:
This will load the tensor into memory.
"""
self._check_validity()
if self.raw is None:
self._load()
assert self.raw is not None
offset = self._offset or 0
length = self._length or self.nbytes
return self.raw[offset : offset + length]

def valid(self) -> bool:
    """Report whether this external tensor is still usable.

    Returns:
        The current value of the internal ``_valid`` flag. The flag is
        set to ``True`` on construction and cleared by ``invalidate()``.
    """
    still_valid = self._valid
    return still_valid

def _check_validity(self) -> None:
    """Guard data access by rejecting an invalidated tensor.

    Raises:
        ValueError: If ``valid()`` returns ``False``.
    """
    # Fast path: nothing to do while the tensor is still valid.
    if self.valid():
        return
    raise ValueError(
        f"The external tensor '{self!r}' is invalidated. The data may be corrupted or deleted."
    )

def invalidate(self) -> None:
    """Mark this external tensor as no longer usable.

    Intended for the case where the data behind the tensor is known to be
    corrupted or deleted. Only the validity flag is cleared here; after
    this call ``valid()`` returns ``False``.
    """
    self._valid = False

def release(self) -> None:
"""Delete all references to the memory buffer and close the memory-mapped file."""
self._array = None
Expand Down
13 changes: 9 additions & 4 deletions onnxscript/ir/_external_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,9 @@ def convert_tensors_to_external(
# is referring to this file, that tensor is now invalid.
# This is a special case we are ok not handling right now.
new_tensors.append(_external_tensor_to_memory_tensor(tensor))
# Mark the original external tensor as invalid because it is now pointing
# to a file that is going to be overwritten.
tensor.invalidate()
else:
new_tensors.append(tensor)
tensors = new_tensors
Expand Down Expand Up @@ -312,7 +315,7 @@ def to_external_data(
An ir.Model with all initializer data equal or above :param:`size_threshold_bytes`
converted to external tensors.
"""
# In-memory or external tensors, if above the threshold, should be converted to or re-saved as external tensors
# In-memory or external tensors, if strictly above the threshold (nbytes > size_threshold_bytes), should be converted to or re-saved as external tensors
initializers_to_become_external = []
# Existing external tensors, if below the threshold, should be loaded to memory
initializers_to_load_to_memory = []
Expand All @@ -321,16 +324,18 @@ def to_external_data(
# Filter out the uninitialized initializer values
continue
if value.const_value.nbytes > size_threshold_bytes:
initializers_to_become_external.append(value.const_value)
initializers_to_become_external.append(value)
elif isinstance(value.const_value, _core.ExternalTensor):
initializers_to_load_to_memory.append(value.const_value)
initializers_to_load_to_memory.append(value)

# Load to memory first, then convert to external tensors, because
# the existing external tensors may be overwritten by the new external data
memory_tensors = convert_tensors_from_external(initializers_to_load_to_memory)

Check failure

Code scanning / lintrunner

MYPY/arg-type Error

Argument 1 to "convert_tensors_from_external" has incompatible type "list[Value]"; expected "Sequence[ExternalTensor]" To disable, use # type: ignore[arg-type]
external_tensors = convert_tensors_to_external(
[v.const_value for v in initializers_to_become_external],

Check failure

Code scanning / lintrunner

MYPY/misc Error

List comprehension has incompatible type List[TensorProtocol | None]; expected List[TensorProtocol] To disable, use # type: ignore[misc]
base_dir=base_dir,
relative_path=relative_path,
)
memory_tensors = convert_tensors_from_external(initializers_to_load_to_memory)

# Replace the initializer values with external tensors and save the model
assert len(initializers_to_become_external) == len(external_tensors)
Expand Down
37 changes: 22 additions & 15 deletions onnxscript/ir/_io_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import unittest

import numpy as np
import parameterized

from onnxscript import ir
from onnxscript.ir import _io
Expand Down Expand Up @@ -77,7 +76,7 @@ def test_save_with_external_data_does_not_modify_model(self):
with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, "model.onnx")
external_data_file = "model.data"
_io.save(model, path, external_data=external_data_file)
_io.save(model, path, external_data=external_data_file, size_threshold_bytes=0)
self.assertTrue(os.path.exists(path))
external_data_path = os.path.join(tmpdir, external_data_file)
self.assertTrue(os.path.exists(external_data_path))
Expand Down Expand Up @@ -109,27 +108,35 @@ def test_save_raise_when_external_data_is_not_relative_path(self):
with self.assertRaises(ValueError):
_io.save(model, path, external_data=external_data_file)

def test_save_with_external_data_invalidates_obsolete_external_tensors(
self, _: str
):
def test_save_with_external_data_invalidates_obsolete_external_tensors(self):
model = _create_simple_model_with_initializers()
self.assertIsInstance(model.graph.initializers["initializer_0"].const_value, ir.Tensor)
with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, "model.onnx")
external_data_file = "model.data"
_io.save(model, path, external_data=external_data_file)
# The original model is modified
initializer_tensor = model.graph.initializers["initializer_0"].const_value
self.assertIsInstance(initializer_tensor, ir.ExternalTensor)

# Now if we create a different initializer and save that model with the same external data file
_io.save(model, path, external_data=external_data_file, size_threshold_bytes=0)
loaded_model = _io.load(path)
# Now if we load the model back, create a different initializer and save
# the model to the same external data file, the existing external tensor
# should be invalidated
tensor_2 = ir.tensor([2.0], dtype=ir.DataType.FLOAT, name="initializer_2")
initializer_2 = _create_initializer(tensor_2)
model.graph.initializers["initializer_2"] = initializer_2
with self.assertRaises(ValueError):
loaded_model.graph.initializers["initializer_2"] = initializer_2
_io.save(
loaded_model, path, external_data=external_data_file, size_threshold_bytes=0
)
initializer_0_tensor = loaded_model.graph.initializers["initializer_0"].const_value
self.assertIsInstance(initializer_0_tensor, ir.ExternalTensor)
self.assertFalse(initializer_0_tensor.valid())
with self.assertRaisesRegex(ValueError, "is invalidated"):
# The existing model has to be modified to use in memory tensors
# for the values to stay correct
_io.save(model, path, external_data=external_data_file)
# for the values to stay correct. Saving again should raise an error
_io.save(
loaded_model,
path,
external_data=external_data_file,
size_threshold_bytes=0,
)


if __name__ == "__main__":
Expand Down

0 comments on commit f5e0724

Please sign in to comment.