Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR] Support safetensors in tensor adapters #1933

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 66 additions & 3 deletions onnxscript/ir/tensor_adapters.py
Original file line number Diff line number Diff line change
@@ -25,11 +25,15 @@
# pylint: disable=import-outside-toplevel

# NOTE: DO NOT import any framework-specific modules here in the global namespace.
# NOTE: We use ir.DataType instead of _enums.DataType to show users how they
# should create custom tensor adapters. This is fine and will not create
# circular imports because the ir.DataType's are not used in the global namespace.

from __future__ import annotations

__all__ = [
"TorchTensor",
"SafetensorsTensor",
]

import ctypes
@@ -46,8 +50,12 @@

class TorchTensor(_core.Tensor):
def __init__(
self, tensor: torch.Tensor, name: str | None = None, doc_string: str | None = None
):
self,
tensor: torch.Tensor,
name: str | None = None,
doc_string: str | None = None,
metadata_props: dict[str, str] | None = None,
) -> None:
# Pass the tensor as the raw data to ir.Tensor's constructor
import torch

@@ -73,7 +81,11 @@
torch.uint64: ir.DataType.UINT64,
}
super().__init__(
tensor, dtype=_TORCH_DTYPE_TO_ONNX[tensor.dtype], name=name, doc_string=doc_string
tensor,
dtype=_TORCH_DTYPE_TO_ONNX[tensor.dtype],
name=name,
doc_string=doc_string,
metadata_props=metadata_props,
)

def numpy(self) -> npt.NDArray:
@@ -120,3 +132,54 @@
tensor.data_ptr()
)
)


class SafetensorsTensor(_core.Tensor):
    """Adapter for Hugging Face's [safetensors](https://github.com/huggingface/safetensors) library.

    This adapter allows you to load tensors from a safetensors file in a
    memory-efficient way and use them in the ONNX IR. The tensor is memory-mapped.
    """

    def __init__(
        self,
        path: str,
        tensor_name: str,
        /,
        dtype: ir.DataType | None = None,
        *,
        shape: ir.Shape | None = None,
        name: str | None = None,
        doc_string: str | None = None,
        metadata_props: dict[str, str] | None = None,
    ) -> None:
        """Create a tensor from a tensor stored in a SafeTensors file.

        Args:
            path: The path to the SafeTensors file.
            tensor_name: The name of the tensor in the SafeTensors file.
            dtype: The data type of the tensor. It can be specified if the value
                is not of a standard NumPy dtype.
            shape: The shape of the tensor. It can be specified if the value
                is not of a standard NumPy dtype.
            name: The name of the ONNX tensor.
            doc_string: The documentation string for the tensor.
            metadata_props: The metadata properties for the tensor.
        """
        # Imported lazily per this module's convention (see the module-level
        # NOTE about framework-specific imports).
        import safetensors

        self._path = path
        self._tensor_name = tensor_name

        # safe_open is a native (Rust) binding without type stubs for the
        # context-manager protocol, so mypy cannot see __enter__/__exit__.
        with safetensors.safe_open(  # type: ignore[attr-defined]
            path, framework="numpy"
        ) as f:
            # The tensor is mmap'ed in memory so we might as well load it
            # at initialization time since it does not take up any extra memory
            array = f.get_tensor(tensor_name)

        super().__init__(
            array,
            dtype=dtype,
            shape=shape,
            name=name,
            doc_string=doc_string,
            metadata_props=metadata_props,
        )