diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 0033da8..8892dea 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -60,6 +60,7 @@ jobs:
           name: coverage-data-${{ matrix.python-version }}
           path: .coverage.*
           if-no-files-found: ignore
+          include-hidden-files: true
 
   Coverage:
     needs: Pytest
diff --git a/changelog.md b/changelog.md
index 4c74ffa..c964610 100644
--- a/changelog.md
+++ b/changelog.md
@@ -1,8 +1,9 @@
 # Changelog
 
-## Unreleased
+## v0.3.5
 
 - Support hashing the `folded_tensor.length` field (via a UserList), which is convenient for caching
+- Improve error messaging when refolding with missing dims
 
 ## v0.3.4
 
diff --git a/foldedtensor/__init__.py b/foldedtensor/__init__.py
index e3eb63f..10bf098 100644
--- a/foldedtensor/__init__.py
+++ b/foldedtensor/__init__.py
@@ -45,7 +45,7 @@ except AttributeError:
     DisableTorchFunctionSubclass = torch._C.DisableTorchFunction
 
 
-__version__ = "0.3.4"
+__version__ = "0.3.5"
 
 
 class FoldedTensorLengths(UserList):
@@ -402,9 +402,16 @@ def refold(self, *dims: Union[Sequence[Union[int, str]], int, str]):
                     "sequence or each arguments to be ints or strings"
                 )
             dims = dims[0]
-        dims = tuple(
-            dim if isinstance(dim, int) else self.full_names.index(dim) for dim in dims
-        )
+        try:
+            dims = tuple(
+                dim if isinstance(dim, int) else self.full_names.index(dim)
+                for dim in dims
+            )
+        except ValueError:
+            raise ValueError(
+                f"Folded tensor with available dimensions {self.full_names} "
+                f"could not be refolded with dimensions {list(dims)}"
+            ) from None
 
         if dims == self.data_dims:
             return self
diff --git a/tests/test_folded_tensor.py b/tests/test_folded_tensor.py
index c8f9e13..a4d2e24 100644
--- a/tests/test_folded_tensor.py
+++ b/tests/test_folded_tensor.py
@@ -431,3 +431,18 @@ def test_hashable_lengths():
     assert tensor.lengths is embedding(tensor).lengths
     assert hash(tensor.lengths) is not None
     assert hash(tensor.lengths) == hash(embedding(tensor).lengths)
+
+
+def test_missing_dims():
+    tensor = as_folded_tensor(
+        [
+            [0, 1, 2],
+            [3, 4],
+        ],
+        full_names=("sample", "token"),
+        dtype=torch.long,
+    )
+    with pytest.raises(ValueError) as e:
+        tensor.refold("line", "token")
+
+    assert "line" in str(e.value)