diff --git a/torchsparsegradutils/indexed_matmul.py b/torchsparsegradutils/indexed_matmul.py index ef0c8f7..b41535c 100644 --- a/torchsparsegradutils/indexed_matmul.py +++ b/torchsparsegradutils/indexed_matmul.py @@ -45,7 +45,7 @@ def segment_mm(a, b, seglen_a): if not a.shape[1] == D1 or not seglen_a.shape[0] == R: raise ValueError("Incompatible size for inputs") - segidx_a = torch.cumsum(seglen_a[:-1], dim=0) + segidx_a = torch.cumsum(seglen_a[:-1], dim=0).cpu() # Ideally the conversions below to nested tensor would be handled natively nested_a = torch.nested.as_nested_tensor(torch.tensor_split(a, segidx_a, dim=0))