[BUG] RuntimeError when using tensor_split with CUDA tensors in segment_mm #58

Open · theo-barfoot opened this issue Oct 21, 2024 · 2 comments

@theo-barfoot (Collaborator) commented:

Description

When running the tests for segment_mm, I encountered a RuntimeError from torch.tensor_split: it expects the tensor_indices_or_sections argument to be on the CPU, but here it is on CUDA.

Steps to Reproduce

1. Use a machine with CUDA enabled.

2. Run the following test:

    pytest torchsparsegradutils/tests/test_indexed_matmul.py

3. Example failure:

    E       RuntimeError: tensor_split expected tensor_indices_or_sections to be on cpu, but it's on cuda:0
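For reference, a minimal standalone reproduction looks roughly like this (the import path is inferred from the test file name and may differ; shapes follow the DGL-style segment_mm convention of a: (N, D1), b: (R, D1, D2), with seglen_a summing to N):

import torch
from torchsparsegradutils.indexed_matmul import segment_mm  # assumed import path

device = torch.device("cuda")
a = torch.randn(10, 4, device=device)              # (N, D1)
b = torch.randn(3, 4, 5, device=device)            # (R, D1, D2)
seglen_a = torch.tensor([3, 3, 4], device=device)  # segment lengths, sum to N

out = segment_mm(a, b, seglen_a)  # raises the RuntimeError shown above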

Relevant Code

The error occurs in the following section of segment_mm:

nested_a = torch.nested.as_nested_tensor(torch.tensor_split(a, segidx_a, dim=0))

The issue arises because segidx_a is created on the GPU (CUDA), but torch.tensor_split requires it to be on the CPU.
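The requirement is easy to confirm in isolation: torch.tensor_split accepts a tensor of split indices only if that tensor is on the CPU, regardless of the device of the data tensor. A minimal demonstration:

import torch

x = torch.arange(12, device="cuda").reshape(6, 2)
idx = torch.tensor([2, 4])                 # split indices on the CPU: fine
torch.tensor_split(x, idx, dim=0)          # -> rows [0:2], [2:4], [4:6]

torch.tensor_split(x, idx.cuda(), dim=0)   # RuntimeError: expected indices on cpu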

Proposed Solution

To resolve this, the segidx_a tensor should be explicitly moved to the CPU before passing it to torch.tensor_split. Here's a possible fix:

segidx_a = torch.cumsum(seglen_a[:-1].cpu(), dim=0)  # Ensure segidx_a is on the CPU
nested_a = torch.nested.as_nested_tensor(torch.tensor_split(a, segidx_a, dim=0))

This ensures compatibility with the expected behavior of tensor_split and avoids the RuntimeError.
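Put together, the fixed segmentation step would look roughly like this (a sketch with illustrative shapes, not the full segment_mm body):

import torch

a = torch.randn(10, 4, device="cuda" if torch.cuda.is_available() else "cpu")
seglen_a = torch.tensor([3, 3, 4], device=a.device)  # lengths sum to a.shape[0]

# cumsum over all but the last length yields the split offsets; .cpu() keeps
# tensor_split happy regardless of which device seglen_a lives on
segidx_a = torch.cumsum(seglen_a[:-1].cpu(), dim=0)
nested_a = torch.nested.as_nested_tensor(list(torch.tensor_split(a, segidx_a, dim=0)))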

Impact

This affects users running the code on CUDA-enabled devices; the issue does not occur in CPU-only environments.

@theo-barfoot self-assigned this Oct 21, 2024
@tvercaut (Member) commented:

I thought I tested it with CUDA, but I may have done it for gather_mm only... Anyway, that looks fine to me. I guess you could just move segidx_a to the CPU instead of seglen_a:

segidx_a = torch.cumsum(seglen_a[:-1], dim=0).cpu()  # Ensure segidx_a is on the CPU

@tvercaut (Member) commented:

Actually, cumsum may be better suited to the CPU, so your initial suggestion could be better.
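(For readers skimming the thread: the two variants discussed differ only in where the small cumsum runs; both leave the split indices on the CPU, which is all tensor_split requires.)

# Variant 1 (initial suggestion): move the lengths to the CPU, cumsum on the CPU
segidx_a = torch.cumsum(seglen_a[:-1].cpu(), dim=0)

# Variant 2: cumsum on the original device, then move the small result to the CPU
segidx_a = torch.cumsum(seglen_a[:-1], dim=0).cpu()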
