feat: hashable lengths
percevalw committed Sep 1, 2024
1 parent acdadc9 · commit 4e790f6
Showing 3 changed files with 43 additions and 12 deletions.
24 changes: 15 additions & 9 deletions changelog.md
@@ -1,42 +1,48 @@
-# v0.3.4
+# Changelog
+
+## Unreleased
+
+- Support hashing the `folded_tensor.lengths` field (via a UserList), which is convenient for caching
+
+## v0.3.4
 
 - Fix a data_dims access issue
 - Marginally improve the speed of handling FoldedTensors in standard torch operations
 - Use default torch types (e.g. `torch.float32` or `torch.int64`)
 
-# v0.3.3
+## v0.3.3
 
 - Handle empty inputs (e.g. `as_folded_tensor([[[], []], [[]]])`) by returning an empty tensor
 - Correctly bubble up errors when converting inputs with varying nesting depths (e.g. `as_folded_tensor([1, [2, 3]])`)
 
-# v0.3.2
+## v0.3.2
 
 - Allow `as_folded_tensor` to be used with no args, as a simple padding function
 
-# v0.3.1
+## v0.3.1
 
 - Enable sharing FoldedTensor instances in a multiprocessing + CUDA context by auto-cloning the indexer before fork-pickling an instance
 - Distribute arm64 wheels for macOS
 
-# v0.3.0
+## v0.3.0
 
 - Allow dims after the last foldable dim during list conversion (e.g. embeddings)
 
-# v0.2.2
+## v0.2.2
 
 - GitHub release :octocat:
 - Fix backpropagation when refolding
 
-# v0.2.1
+## v0.2.1
 
 - Improve performance by computing only the new "padded to flattened" indexer (and not the previous one) when refolding
 
-# v0.2.0
+## v0.2.0
 
 - Remove the C++ torch dependency in favor of NumPy, due to the lack of torch ABI backward/forward compatibility, which made the pre-built wheels unusable in most cases
 - Require dtype to be specified when creating a FoldedTensor from a nested list
 
-# v0.1.0
+## v0.1.0
 
 Inception! :tada:
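Note: the "convenient for caching" motivation in the Unreleased entry refers to keying a cache (a dict, `functools.lru_cache`, etc.) on a tensor's `lengths`. A minimal sketch of the idea, assuming the usual foldedtensor layout in which `lengths` holds one list of lengths per nesting depth (outermost first); `padding_mask` is a hypothetical helper, not part of this commit:

    from functools import lru_cache

    import torch


    @lru_cache(maxsize=128)
    def padding_mask(lengths):
        # Keyed on the lengths object itself: FoldedTensorLengths hashes by
        # identity, so the mask is recomputed only for unseen lengths objects.
        inner = lengths[-1]  # per-sequence lengths at the deepest level
        positions = torch.arange(max(inner))
        return positions[None, :] < torch.tensor(inner)[:, None]

Because torch operations propagate the same lengths object to their outputs (see the test further down), a pipeline of operations keeps hitting the same cache entry.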
17 changes: 14 additions & 3 deletions foldedtensor/__init__.py
@@ -1,3 +1,5 @@
+import typing
+from collections import UserList
 from multiprocessing.reduction import ForkingPickler
 from typing import List, Optional, Sequence, Tuple, Union
 
@@ -46,6 +48,15 @@
 __version__ = "0.3.4"
 
 
+class FoldedTensorLengths(UserList):
+    def __hash__(self):
+        return id(self)
+
+
+if typing.TYPE_CHECKING:
+    FoldedTensorLengths = List[List[int]]  # noqa: F811
+
+
 # noinspection PyMethodOverriding
 class Refold(Function):
     @staticmethod
@@ -179,7 +190,7 @@ def as_folded_tensor(
         )
         result = FoldedTensor(
             data=data,
-            lengths=lengths,
+            lengths=FoldedTensorLengths(lengths),
             data_dims=data_dims,
             full_names=full_names,
             indexer=torch.from_numpy(np_indexer).to(data.device),
@@ -207,7 +218,7 @@ def as_folded_tensor(
         lengths = (list(lengths) + [[0]] * deepness)[:deepness]
         result = FoldedTensor(
             data=padded,
-            lengths=lengths,
+            lengths=FoldedTensorLengths(lengths),
             data_dims=data_dims,
             full_names=full_names,
             indexer=indexer,
@@ -269,7 +280,7 @@ class FoldedTensor(torch.Tensor):
     def __new__(
         cls,
         data: torch.Tensor,
-        lengths: List[List[int]],
+        lengths: FoldedTensorLengths,
         data_dims: Sequence[int],
         full_names: Sequence[str],
         indexer: torch.Tensor,
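Note on the design: `__hash__` returns `id(self)`, so a `FoldedTensorLengths` hashes by object identity, while the equality inherited from `UserList` still compares contents; the `typing.TYPE_CHECKING` rebinding lets type checkers keep treating the field as a plain `List[List[int]]`. A minimal illustration of the hashing semantics (not from the repository):

    from foldedtensor import FoldedTensorLengths

    a = FoldedTensorLengths([[2], [3, 2]])
    b = FoldedTensorLengths([[2], [3, 2]])

    assert a == b              # UserList equality compares contents
    assert hash(a) != hash(b)  # identity hashes differ between two live objects
    assert hash(a) == hash(a)  # and stay stable over an object's lifetime

This deviates from the usual contract that equal objects hash equally, which is acceptable here because a cache is only expected to hit when the very same lengths object comes back.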
14 changes: 14 additions & 0 deletions tests/test_folded_tensor.py
@@ -417,3 +417,17 @@ def test_max():
     values, indices = ft.max(-1)
     assert (values == torch.tensor([2, 4])).all()
     assert (indices == torch.tensor([2, 1])).all()
+
+
+def test_hashable_lengths():
+    tensor = as_folded_tensor(
+        [
+            [0, 1, 2],
+            [3, 4],
+        ],
+        dtype=torch.long,
+    )
+    embedding = torch.nn.Embedding(10, 16)
+    assert tensor.lengths is embedding(tensor).lengths
+    assert hash(tensor.lengths) is not None
+    assert hash(tensor.lengths) == hash(embedding(tensor).lengths)
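
The test hinges on FoldedTensor operations propagating the lengths object itself (`tensor.lengths is embedding(tensor).lengths`), which is what makes identity-based caching effective in practice. A sketch of a cache hit across an operation, reusing the hypothetical `padding_mask` helper from above:

    ft = as_folded_tensor([[0, 1, 2], [3, 4]], dtype=torch.long)
    emb = torch.nn.Embedding(10, 16)

    m1 = padding_mask(ft.lengths)
    m2 = padding_mask(emb(ft).lengths)
    assert m1 is m2  # same lengths object in, same cached mask out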
