RuntimeError when using tensor_split with CUDA tensors in segment_mm
Description
When running the tests for segment_mm, I encountered a RuntimeError related to the use of tensor_split with CUDA tensors. Specifically, tensor_split expects the tensor_indices_or_sections argument to be on the CPU, but in this case it is on CUDA.
Steps to Reproduce
Use a machine with CUDA enabled and run the tests for segment_mm.
Relevant Code
The error occurs where segidx_a is created: it lives on the GPU (CUDA), but torch.tensor_split requires its index tensor to be on the CPU.
Proposed Solution
To resolve this, the segidx_a tensor should be explicitly moved to the CPU before passing it to torch.tensor_split. Here's a possible fix:
segidx_a = torch.cumsum(seglen_a[:-1].cpu(), dim=0)  # ensure segidx_a is on the CPU
nested_a = torch.nested.as_nested_tensor(torch.tensor_split(a, segidx_a, dim=0))
This ensures compatibility with the expected behavior of tensor_split and avoids the RuntimeError.
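For anyone unfamiliar with the indexing scheme involved: tensor_split interprets its index argument as split points, and those split points are the cumulative sums of all segment lengths except the last. Here is a minimal stdlib-only sketch of that logic (plain Python lists stand in for tensors; split_by_seglen and the sample data are hypothetical, not part of segment_mm):

```python
from itertools import accumulate

def split_by_seglen(rows, seglen):
    # Split points are the cumulative sums of all but the last segment
    # length -- the same quantity segidx_a holds in the fix above.
    segidx = list(accumulate(seglen[:-1]))
    bounds = [0] + segidx + [len(rows)]
    return [rows[b:e] for b, e in zip(bounds, bounds[1:])]

a = [[1], [2], [3], [4], [5], [6]]   # six rows
seglen_a = [2, 1, 3]                 # three segments
print(split_by_seglen(a, seglen_a))
# [[[1], [2]], [[3]], [[4], [5], [6]]]
```

Since the split points are plain integer offsets, nothing about them is device-specific, which is why moving them to the CPU is safe and loses no information.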
Impact
This affects users running the code on CUDA-enabled devices; the issue does not occur in CPU-only environments.
I thought I tested it with cuda but I may have done it for gather_mm only... Anyway, that looks fine by me. I guess you could just move segidx_a to the cpu instead of seglen_a:
segidx_a = torch.cumsum(seglen_a[:-1], dim=0).cpu()  # ensure segidx_a is on the CPU